class BasePeerPool(Service, AsyncIterable[BasePeer]):
    """
    PeerPool maintains connections to up to max_peers on a given network.
    """
    _report_stats_interval = 60
    _report_metrics_interval = 15  # for influxdb/grafana reporting
    _peer_boot_timeout = DEFAULT_PEER_BOOT_TIMEOUT
    _event_bus: EndpointAPI = None
    _handshake_locks: ResourceLock[NodeAPI]

    peer_reporter_registry_class: Type[
        PeerReporterRegistry[Any]] = PeerReporterRegistry[BasePeer]

    def __init__(self,
                 privkey: datatypes.PrivateKey,
                 context: BasePeerContext,
                 max_peers: int = DEFAULT_MAX_PEERS,
                 event_bus: EndpointAPI = None,
                 metrics_registry: MetricsRegistry = None,
                 ) -> None:
        self.logger = get_logger(self.__module__ + '.' + self.__class__.__name__)

        self.privkey = privkey
        self.max_peers = max_peers
        self.context = context

        self.connected_nodes: Dict[SessionAPI, BasePeer] = {}

        self._subscribers: List[PeerSubscriber] = []
        self._event_bus = event_bus

        if metrics_registry is None:
            # Initialize with a MetricsRegistry from pyformance as p2p can not depend on Trinity
            # This is so that we don't need to pass a MetricsRegistry in tests and mocked pools.
            metrics_registry = MetricsRegistry()
        self._active_peer_counter = metrics_registry.counter('trinity.p2p/peers.counter')
        self._peer_reporter_registry = self.get_peer_reporter_registry(metrics_registry)

        # Restricts the number of concurrent connection attempts that can be made
        self._connection_attempt_lock = asyncio.BoundedSemaphore(
            MAX_CONCURRENT_CONNECTION_ATTEMPTS)

        # Ensure we can only have a single concurrent handshake in flight per remote
        self._handshake_locks = ResourceLock()

        self.peer_backends = self.setup_peer_backends()
        self.connection_tracker = self.setup_connection_tracker()

    @cached_property
    def has_event_bus(self) -> bool:
        return self._event_bus is not None

    def get_event_bus(self) -> EndpointAPI:
        if self._event_bus is None:
            raise AttributeError("No event bus configured for this peer pool")
        return self._event_bus

    def setup_connection_tracker(self) -> BaseConnectionTracker:
        """
        Return an instance of `p2p.tracking.connection.BaseConnectionTracker`
        which will be used to track peer connection failures.
        """
        return NoopConnectionTracker()

    def setup_peer_backends(self) -> Tuple[BasePeerBackend, ...]:
        if self.has_event_bus:
            return (
                DiscoveryPeerBackend(self.get_event_bus()),
                BootnodesPeerBackend(self.get_event_bus()),
            )
        else:
            self.logger.warning("No event bus configured for peer pool.")
            return ()

    @property
    def available_slots(self) -> int:
        return self.max_peers - len(self)

    async def _add_peers_from_backend(
            self,
            backend: BasePeerBackend,
            should_skip_fn: Callable[[Tuple[NodeID, ...], NodeAPI], bool],
    ) -> int:
        connected_node_ids = {peer.remote.id for peer in self.connected_nodes.values()}
        # Only ask for random bootnodes if we're not connected to any peers.
        if isinstance(backend, BootnodesPeerBackend) and len(connected_node_ids) > 0:
            return 0

        try:
            blacklisted_node_ids = await asyncio.wait_for(
                self.connection_tracker.get_blacklisted(), timeout=1)
        except asyncio.TimeoutError:
            self.logger.warning(
                "Timed out getting blacklisted peers from connection tracker, pausing peer "
                "addition until we can get that.")
            return 0

        skip_list = connected_node_ids.union(blacklisted_node_ids)
        should_skip_fn = functools.partial(should_skip_fn, skip_list)

        # Request a large batch on every iteration as that will effectively push DiscoveryService
        # to trigger new peer lookups in order to find enough compatible peers to fulfill our
        # request. There's probably some room for experimentation here in order to find an optimal
        # value.
        max_candidates = self.available_slots * OVER_PROVISION_MISSING_PEERS
        try:
            candidates = await asyncio.wait_for(
                backend.get_peer_candidates(
                    max_candidates=max_candidates,
                    should_skip_fn=should_skip_fn,
                ),
                timeout=REQUEST_PEER_CANDIDATE_TIMEOUT,
            )
        except asyncio.TimeoutError:
            self.logger.warning("PeerCandidateRequest timed out to backend %s", backend)
            return 0
        else:
            self.logger.debug("Got %d peer candidates from backend %s", len(candidates), backend)
            if candidates:
                await self.connect_to_nodes(candidates)
            return len(candidates)

    def __len__(self) -> int:
        return len(self.connected_nodes)

    @property
    @abstractmethod
    def peer_factory_class(self) -> Type[BasePeerFactory]:
        ...

    def get_peer_factory(self) -> BasePeerFactory:
        return self.peer_factory_class(
            privkey=self.privkey,
            context=self.context,
            event_bus=self._event_bus,
        )

    def get_peer_reporter_registry(self,
                                   metrics_registry: MetricsRegistry,
                                   ) -> PeerReporterRegistry[BasePeer]:
        return self.peer_reporter_registry_class(metrics_registry)

    async def get_protocol_capabilities(self) -> Tuple[Tuple[str, int], ...]:
        return tuple(
            (handshaker.protocol_class.name, handshaker.protocol_class.version)
            for handshaker in await self.get_peer_factory().get_handshakers()
        )

    @property
    def is_full(self) -> bool:
        return len(self) >= self.max_peers

    def is_valid_connection_candidate(self, candidate: NodeAPI) -> bool:
        # connect to no more than 2 nodes with the same IP
        nodes_by_ip = groupby(
            operator.attrgetter('remote.address.ip'),
            self.connected_nodes.keys(),
        )
        matching_ip_nodes = nodes_by_ip.get(candidate.address.ip, [])
        return len(matching_ip_nodes) <= 2

    def subscribe(self, subscriber: PeerSubscriber) -> None:
        self._subscribers.append(subscriber)
        for peer in self.connected_nodes.values():
            subscriber.register_peer(peer)
            peer.add_subscriber(subscriber)

    def unsubscribe(self, subscriber: PeerSubscriber) -> None:
        if subscriber in self._subscribers:
            self._subscribers.remove(subscriber)
        for peer in self.connected_nodes.values():
            peer.remove_subscriber(subscriber)

    async def start_peer(self, peer: BasePeer) -> None:
        self.manager.run_child_service(peer.connection)
        await asyncio.wait_for(
            peer.connection.get_manager().wait_started(),
            timeout=PEER_READY_TIMEOUT,
        )
        try:
            await peer.connection.run_peer(peer)
        except asyncio.TimeoutError as err:
            self.logger.debug('Timeout waiting for peer to start: %s', err)
            return

        if peer.get_manager().is_running:
            self._add_peer(peer, ())
        else:
            self.logger.debug("%s was cancelled immediately, not adding to pool", peer)

        try:
            await asyncio.wait_for(
                peer.boot_manager.get_manager().wait_finished(),
                timeout=self._peer_boot_timeout,
            )
        except asyncio.TimeoutError as err:
            self.logger.debug('Timeout waiting for peer to boot: %s', err)
            await peer.disconnect(DisconnectReason.TIMEOUT)
            return
        else:
            if not peer.get_manager().is_running:
                self.logger.debug('%s disconnected during boot-up, dropped from pool', peer)

    def _add_peer(self, peer: BasePeer, msgs: Tuple[PeerMessage, ...]) -> None:
        """Add the given peer to the pool.

        Apart from adding it to our list of connected nodes and adding each of
        our subscribers to the peer, we also add the given messages to our
        subscribers' queues.
""" if len(self) < QUIET_PEER_POOL_SIZE: logger = self.logger.info else: logger = self.logger.debug logger("Adding %s to pool", peer) self.connected_nodes[peer.session] = peer self._active_peer_counter.inc() self._peer_reporter_registry.assign_peer_reporter(peer) peer.add_finished_callback(self._peer_finished) for subscriber in self._subscribers: subscriber.register_peer(peer) peer.add_subscriber(subscriber) for msg in msgs: subscriber.add_msg(msg) @abstractmethod async def maybe_connect_more_peers(self) -> None: ... async def run(self) -> None: if self.has_event_bus: self.manager.run_daemon_task(self.maybe_connect_more_peers) self.manager.run_daemon_task(self._periodically_report_stats) self.manager.run_daemon_task(self._periodically_report_metrics) await self.manager.wait_finished() async def connect(self, remote: NodeAPI) -> BasePeer: """ Connect to the given remote and return a Peer instance when successful. Returns None if the remote is unreachable, times out or is useless. """ if not self._handshake_locks.is_locked(remote): self.logger.warning( "Tried to connect to %s without acquiring lock first!", remote) if any(peer.remote == remote for peer in self.connected_nodes.values()): self.logger.warning( "Attempted to connect to peer we are already connected to: %s", remote) raise IneligiblePeer(f"Already connected to {remote}") try: self.logger.debug("Connecting to %s...", remote) task = asyncio.ensure_future( self.get_peer_factory().handshake(remote)) return await asyncio.wait_for(task, timeout=HANDSHAKE_TIMEOUT) except BadAckMessage: # This is kept separate from the # `COMMON_PEER_CONNECTION_EXCEPTIONS` to be sure that we aren't # silencing an error in our authentication code. self.logger.error('Got bad auth ack from %r', remote) # dump the full stacktrace in the debug logs self.logger.debug('Got bad auth ack from %r', remote, exc_info=True) raise except MalformedMessage as e: # This is kept separate from the # `COMMON_PEER_CONNECTION_EXCEPTIONS` to be sure that we aren't # silencing an error in how we decode messages during handshake. self.logger.error( 'Got malformed response from %r during handshake', remote) # dump the full stacktrace in the debug logs self.logger.debug('Got malformed response from %r', remote, exc_info=True) self.connection_tracker.record_failure(remote, e) raise except HandshakeFailure as e: self.logger.debug("Could not complete handshake with %r: %s", remote, repr(e)) self.connection_tracker.record_failure(remote, e) raise except COMMON_PEER_CONNECTION_EXCEPTIONS as e: self.logger.debug("Could not complete handshake with %r: %s", remote, repr(e)) self.connection_tracker.record_failure(remote, e) raise finally: # XXX: We sometimes get an exception here but the task is finished and with # an exception as well. No idea how that happens but if we don't consume the # task's exception, asyncio complains. if not task.done(): task.cancel() try: await task except (Exception, asyncio.CancelledError): pass async def connect_to_nodes(self, nodes: Sequence[NodeAPI]) -> None: # create an generator for the nodes nodes_iter = iter(nodes) while True: if self.is_full or not self.manager.is_running: return # only attempt to connect to up to the maximum number of available # peer slots that are open. batch_size = clamp(1, 10, self.available_slots) batch = tuple(take(batch_size, nodes_iter)) # There are no more *known* nodes to connect to. 
            if not batch:
                return
            self.logger.debug(
                'Initiating %d peer connection attempts with %d open peer slots',
                len(batch),
                self.available_slots,
            )
            # Try to connect to the peers concurrently.
            await asyncio.gather(*(self.connect_to_node(node) for node in batch))

    def lock_node_for_handshake(self, node: NodeAPI) -> AsyncContextManager[None]:
        return self._handshake_locks.lock(node)

    def is_connected_to_node(self, node: NodeAPI) -> bool:
        return any(
            node == session.remote
            for session in self.connected_nodes.keys()
        )

    async def connect_to_node(self, node: NodeAPI) -> None:
        """
        Connect to a single node quietly aborting if the peer pool is full or
        shutting down, or one of the expected peer level exceptions is raised
        while connecting.
        """
        if self.is_full or not self.manager.is_running:
            self.logger.warning("Asked to connect to node when either full or not operational")
            return

        if self._handshake_locks.is_locked(node):
            self.logger.info(
                "Asked to connect to node when handshake lock is already locked, will wait")

        async with self.lock_node_for_handshake(node):
            if self.is_connected_to_node(node):
                self.logger.debug(
                    "Aborting outbound connection attempt to %s. Already connected!",
                    node,
                )
                return

            try:
                async with self._connection_attempt_lock:
                    peer = await self.connect(node)
            except ALLOWED_PEER_CONNECTION_EXCEPTIONS:
                return

            # Check again to see if we have *become* full since the previous
            # check.
            if self.is_full:
                self.logger.debug(
                    "Successfully connected to %s but peer pool is full. Disconnecting.",
                    peer,
                )
                await peer.disconnect(DisconnectReason.TOO_MANY_PEERS)
                return
            elif not self.manager.is_running:
                self.logger.debug(
                    "Successfully connected to %s but peer pool is closing. Disconnecting.",
                    peer,
                )
                await peer.disconnect(DisconnectReason.CLIENT_QUITTING)
                return
            else:
                await self.start_peer(peer)

    def _peer_finished(self, peer: BasePeer) -> None:
        """
        Remove the given peer from our list of connected nodes.
        This is passed as a callback to be called when a peer finishes.
        """
        if peer.session in self.connected_nodes:
            self.logger.debug(
                "Removing %s from pool: local_reason=%s remote_reason=%s",
                peer,
                peer.local_disconnect_reason,
                peer.remote_disconnect_reason,
            )
            self.connected_nodes.pop(peer.session)
        else:
            self.logger.warning(
                "%s finished but was not found in connected_nodes (%s)",
                peer,
                tuple(self.connected_nodes.values()),
            )

        for subscriber in self._subscribers:
            subscriber.deregister_peer(peer)

        self._active_peer_counter.dec()
        self._peer_reporter_registry.unassign_peer_reporter(peer)

    async def __aiter__(self) -> AsyncIterator[BasePeer]:
        for peer in tuple(self.connected_nodes.values()):
            # Yield control to ensure we process any disconnection requests from peers. Otherwise
            # we could return peers that should have been disconnected already.
            await asyncio.sleep(0)
            if peer.get_manager().is_running and not peer.is_closing:
                yield peer

    async def _periodically_report_metrics(self) -> None:
        while self.manager.is_running:
            self._peer_reporter_registry.trigger_peer_reports()
            await asyncio.sleep(self._report_metrics_interval)

    async def _periodically_report_stats(self) -> None:
        while self.manager.is_running:
            inbound_peers = len([
                peer for peer in self.connected_nodes.values() if peer.inbound
            ])
            self.logger.info(
                "Connected peers: %d inbound, %d outbound",
                inbound_peers,
                (len(self.connected_nodes) - inbound_peers),
            )
            subscribers = len(self._subscribers)
            if subscribers:
                longest_queue = max(self._subscribers, key=operator.attrgetter('queue_size'))
                self.logger.debug(
                    "Peer subscribers: %d, longest queue: %s(%d)",
                    subscribers,
                    longest_queue.__class__.__name__,
                    longest_queue.queue_size,
                )

            self.logger.debug("== Peer details ==")
            # make a copy, because we might edit the original during iteration
            peers = tuple(self.connected_nodes.values())
            for peer in peers:
                if not peer.get_manager().is_running:
                    self.logger.debug(
                        "%s is no longer alive but had not been removed from pool", peer)
                    continue
                self.logger.debug(
                    "%s: uptime=%s, received_msgs=%d",
                    peer,
                    humanize_seconds(peer.uptime),
                    peer.received_msgs_count,
                )
                self.logger.debug(
                    "client_version_string='%s'",
                    peer.safe_client_version_string,
                )
                try:
                    for line in peer.get_extra_stats():
                        self.logger.debug("    %s", line)
                except (UnknownAPI, PeerConnectionLost) as exc:
                    self.logger.debug("    Failure during stats lookup: %r", exc)
            self.logger.debug("== End peer details ==")
            await asyncio.sleep(self._report_stats_interval)
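
# --- Illustrative usage sketch (not part of the pool implementation) ---
# A minimal sketch of how a consumer is expected to drive this Service-based
# pool: subscribe before running it, run the pool as a background service,
# and async-iterate it for live peers. The `DemoSubscriber` name, the choice
# of `msg_queue_maxsize`, and the use of the base `Command` type to mean
# "all commands" are assumptions; `background_asyncio_service` is assumed to
# come from the same `async-service` library that provides `Service`.
from async_service import background_asyncio_service


class DemoSubscriber(PeerSubscriber):
    # Subscribe to every command type; a real subscriber narrows this set.
    subscription_msg_types = frozenset({Command})
    msg_queue_maxsize = 100

    def register_peer(self, peer: BasePeer) -> None:
        # Called for peers already in the pool at subscribe() time as well as
        # for every peer added afterwards via _add_peer().
        print(f"peer joined the pool: {peer}")


async def run_pool_with_subscriber(pool: BasePeerPool) -> None:
    subscriber = DemoSubscriber()
    pool.subscribe(subscriber)
    try:
        async with background_asyncio_service(pool):
            # __aiter__ skips peers that are closing or no longer running.
            async for peer in pool:
                print("live peer:", peer)
    finally:
        pool.unsubscribe(subscriber)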
class Multiplexer(CancellableMixin, MultiplexerAPI):
    logger = cast(ExtendedDebugLogger, logging.getLogger('p2p.multiplexer.Multiplexer'))

    _multiplex_token: CancelToken

    _transport: TransportAPI
    _msg_counts: DefaultDict[Type[CommandAPI], int]

    _protocol_locks: ResourceLock
    _protocol_queues: Dict[Type[ProtocolAPI], 'asyncio.Queue[Tuple[CommandAPI, Payload]]']

    def __init__(self,
                 transport: TransportAPI,
                 base_protocol: P2PProtocol,
                 protocols: Sequence[ProtocolAPI],
                 token: CancelToken = None,
                 max_queue_size: int = 4096) -> None:
        if token is None:
            loop = None
        else:
            loop = token.loop
        base_token = CancelToken(f'multiplexer[{transport.remote}]', loop=loop)

        if token is None:
            self.cancel_token = base_token
        else:
            self.cancel_token = base_token.chain(token)

        self._transport = transport
        # the base `p2p` protocol instance.
        self._base_protocol = base_protocol

        # the sub-protocol instances
        self._protocols = protocols

        # Lock to ensure that multiple call sites cannot concurrently stream
        # messages.
        self._multiplex_lock = asyncio.Lock()

        # Lock management on a per-protocol basis to ensure we only have one
        # stream consumer for each protocol.
        self._protocol_locks = ResourceLock()

        # Each protocol gets a queue where messages for the individual protocol
        # are placed when streamed from the transport
        self._protocol_queues = {
            type(protocol): asyncio.Queue(max_queue_size)
            for protocol in self.get_protocols()
        }

        self._msg_counts = collections.defaultdict(int)

    def __str__(self) -> str:
        protocol_infos = ','.join(tuple(
            f"{proto.name}:{proto.version}"
            for proto in self.get_protocols()
        ))
        return f"Multiplexer[{protocol_infos}]"

    def __repr__(self) -> str:
        return f"<{self}>"

    #
    # Transport API
    #
    def get_transport(self) -> TransportAPI:
        return self._transport

    #
    # Message Counts
    #
    def get_total_msg_count(self) -> int:
        return sum(self._msg_counts.values())

    #
    # Proxy Transport methods
    #
    @property
    def remote(self) -> NodeAPI:
        return self._transport.remote

    @property
    def is_closing(self) -> bool:
        return self._transport.is_closing

    def close(self) -> None:
        self._transport.close()
        self.cancel_token.trigger()

    #
    # Protocol API
    #
    def has_protocol(self, protocol_identifier: Union[ProtocolAPI, Type[ProtocolAPI]]) -> bool:
        try:
            if isinstance(protocol_identifier, Protocol):
                self.get_protocol_by_type(type(protocol_identifier))
                return True
            elif isinstance(protocol_identifier, type):
                self.get_protocol_by_type(protocol_identifier)
                return True
            else:
                raise TypeError(
                    f"Unsupported protocol value: {protocol_identifier} of type "
                    f"{type(protocol_identifier)}"
                )
        except UnknownProtocol:
            return False

    def get_protocol_by_type(self, protocol_class: Type[TProtocol]) -> TProtocol:
        if protocol_class is P2PProtocol:
            return cast(TProtocol, self._base_protocol)
        for protocol in self._protocols:
            if type(protocol) is protocol_class:
                return cast(TProtocol, protocol)
        raise UnknownProtocol(f"No protocol found with type {protocol_class}")

    def get_base_protocol(self) -> P2PProtocol:
        return self._base_protocol

    def get_protocols(self) -> Tuple[ProtocolAPI, ...]:
        return tuple(cons(self._base_protocol, self._protocols))

    #
    # Streaming API
    #
    async def stream_protocol_messages(self,
                                       protocol_identifier: Union[ProtocolAPI, Type[ProtocolAPI]],
                                       ) -> AsyncIterator[Tuple[CommandAPI, Payload]]:
        """
        Stream the messages for the specified protocol.
""" if isinstance(protocol_identifier, Protocol): protocol_class = type(protocol_identifier) elif isinstance(protocol_identifier, type) and issubclass( protocol_identifier, Protocol): protocol_class = protocol_identifier else: raise TypeError("Unknown protocol identifier: {protocol}") if not self.has_protocol(protocol_class): raise UnknownProtocol(f"Unknown protocol '{protocol_class}'") if self._protocol_locks.is_locked(protocol_class): raise Exception( f"Streaming lock for {protocol_class} is not free.") async with self._protocol_locks.lock(protocol_class): msg_queue = self._protocol_queues[protocol_class] if not hasattr(self, '_multiplex_token'): raise Exception("Multiplexer is not multiplexed") token = self._multiplex_token while not self.is_closing and not token.triggered: try: # We use an optimistic strategy here of using # `get_nowait()` to reduce the number of times we yield to # the event loop. Since this is an async generator it will # yield to the loop each time it returns a value so we # don't have to worry about this blocking other processes. yield msg_queue.get_nowait() except asyncio.QueueEmpty: yield await self.wait(msg_queue.get(), token=token) # # Message reading and streaming API # @asynccontextmanager async def multiplex(self) -> AsyncIterator[None]: """ API for running the background task that feeds individual protocol queues that allows each individual protocol to stream only its own messages. """ # We generate a new token for each time the multiplexer is used to # multiplex so that we can reliably cancel it without requiring the # master token for the multiplexer to be cancelled. async with self._multiplex_lock: multiplex_token = CancelToken( 'multiplex', loop=self.cancel_token.loop, ).chain(self.cancel_token) self._multiplex_token = multiplex_token fut = asyncio.ensure_future(self._do_multiplexing(multiplex_token)) try: yield finally: multiplex_token.trigger() del self._multiplex_token if fut.done(): fut.result() else: fut.cancel() try: await fut except asyncio.CancelledError: pass async def _do_multiplexing(self, token: CancelToken) -> None: """ Background task that reads messages from the transport and feeds them into individual queues for each of the protocols. """ msg_stream = self.wait_iter(stream_transport_messages( self._transport, self._base_protocol, *self._protocols, token=token, ), token=token) async for protocol, cmd, msg in msg_stream: # track total number of messages received for each command type. self._msg_counts[type(cmd)] += 1 queue = self._protocol_queues[type(protocol)] try: # We must use `put_nowait` here to ensure that in the event # that a single protocol queue is full that we don't block # other protocol messages getting through. queue.put_nowait((cmd, msg)) except asyncio.QueueFull: self.logger.error( ("Multiplexing queue for protocol '%s' full. " "discarding message: %s"), protocol, cmd, )
class BasePeerPool(BaseService, AsyncIterable[BasePeer]):
    """
    PeerPool maintains connections to up to max_peers on a given network.
    """
    _report_interval = 60
    _peer_boot_timeout = DEFAULT_PEER_BOOT_TIMEOUT
    _event_bus: EndpointAPI = None

    def __init__(self,
                 privkey: datatypes.PrivateKey,
                 context: BasePeerContext,
                 max_peers: int = DEFAULT_MAX_PEERS,
                 token: CancelToken = None,
                 event_bus: EndpointAPI = None,
                 ) -> None:
        super().__init__(token)

        self.privkey = privkey
        self.max_peers = max_peers
        self.context = context

        self.connected_nodes: Dict[NodeAPI, BasePeer] = {}

        self._subscribers: List[PeerSubscriber] = []
        self._event_bus = event_bus

        # Restricts the number of concurrent connection attempts that can be made
        self._connection_attempt_lock = asyncio.BoundedSemaphore(
            MAX_CONCURRENT_CONNECTION_ATTEMPTS)

        # Ensure we can only have a single concurrent handshake in flight per remote
        self._handshake_locks = ResourceLock()

        self.peer_backends = self.setup_peer_backends()
        self.connection_tracker = self.setup_connection_tracker()

    @property
    def has_event_bus(self) -> bool:
        return self._event_bus is not None

    def get_event_bus(self) -> EndpointAPI:
        if self._event_bus is None:
            raise AttributeError("No event bus configured for this peer pool")
        return self._event_bus

    def setup_connection_tracker(self) -> BaseConnectionTracker:
        """
        Return an instance of `p2p.tracking.connection.BaseConnectionTracker`
        which will be used to track peer connection failures.
        """
        return NoopConnectionTracker()

    def setup_peer_backends(self) -> Tuple[BasePeerBackend, ...]:
        if self.has_event_bus:
            return (
                DiscoveryPeerBackend(self.get_event_bus()),
                BootnodesPeerBackend(self.get_event_bus()),
            )
        else:
            self.logger.warning("No event bus configured for peer pool.")
            return ()

    async def _add_peers_from_backend(self, backend: BasePeerBackend) -> None:
        available_slots = self.max_peers - len(self)
        try:
            connected_remotes = {
                peer.remote for peer in self.connected_nodes.values()
            }
            candidates = await self.wait(
                backend.get_peer_candidates(
                    num_requested=available_slots,
                    connected_remotes=connected_remotes,
                ),
                timeout=REQUEST_PEER_CANDIDATE_TIMEOUT,
            )
        except TimeoutError:
            self.logger.warning("PeerCandidateRequest timed out to backend %s", backend)
            return
        else:
            self.logger.debug2(
                "Got candidates from backend %s (%s)",
                backend,
                candidates,
            )
            if candidates:
                await self.connect_to_nodes(iter(candidates))

    async def maybe_connect_more_peers(self) -> None:
        rate_limiter = TokenBucket(
            rate=1 / PEER_CONNECT_INTERVAL,
            capacity=MAX_SEQUENTIAL_PEER_CONNECT,
        )

        while self.is_operational:
            if self.is_full:
                await self.sleep(PEER_CONNECT_INTERVAL)
                continue

            await self.wait(rate_limiter.take())

            try:
                await self.wait(asyncio.gather(*(
                    self._add_peers_from_backend(backend)
                    for backend in self.peer_backends
                )))
            except OperationCancelled:
                break

    def __len__(self) -> int:
        return len(self.connected_nodes)

    @property
    @abstractmethod
    def peer_factory_class(self) -> Type[BasePeerFactory]:
        ...
    def get_peer_factory(self) -> BasePeerFactory:
        return self.peer_factory_class(
            privkey=self.privkey,
            context=self.context,
            event_bus=self._event_bus,
            token=self.cancel_token,
        )

    @property
    def is_full(self) -> bool:
        return len(self) >= self.max_peers

    def is_valid_connection_candidate(self, candidate: NodeAPI) -> bool:
        # connect to no more than 2 nodes with the same IP
        nodes_by_ip = groupby(
            operator.attrgetter('address.ip'),
            self.connected_nodes.keys(),
        )
        matching_ip_nodes = nodes_by_ip.get(candidate.address.ip, [])
        return len(matching_ip_nodes) <= 2

    def subscribe(self, subscriber: PeerSubscriber) -> None:
        self._subscribers.append(subscriber)
        for peer in self.connected_nodes.values():
            subscriber.register_peer(peer)
            peer.add_subscriber(subscriber)

    def unsubscribe(self, subscriber: PeerSubscriber) -> None:
        if subscriber in self._subscribers:
            self._subscribers.remove(subscriber)
        for peer in self.connected_nodes.values():
            peer.remove_subscriber(subscriber)

    async def start_peer(self, peer: BasePeer) -> None:
        self.run_child_service(peer.connection)
        await self.wait(peer.connection.events.started.wait(), timeout=1)

        self.run_child_service(peer)
        await self.wait(peer.events.started.wait(), timeout=1)
        if peer.is_operational:
            self._add_peer(peer, ())
        else:
            self.logger.debug("%s was cancelled immediately, not adding to pool", peer)

        try:
            await self.wait(
                peer.boot_manager.events.finished.wait(),
                timeout=self._peer_boot_timeout,
            )
        except TimeoutError as err:
            self.logger.debug('Timeout waiting for peer to boot: %s', err)
            await peer.disconnect(DisconnectReason.timeout)
            return
        except HandshakeFailure as err:
            self.connection_tracker.record_failure(peer.remote, err)
            raise
        else:
            if not peer.is_operational:
                self.logger.debug('%s disconnected during boot-up, dropped from pool', peer)

    def _add_peer(self, peer: BasePeer, msgs: Tuple[PeerMessage, ...]) -> None:
        """Add the given peer to the pool.

        Apart from adding it to our list of connected nodes and adding each of
        our subscribers to the peer, we also add the given messages to our
        subscribers' queues.
        """
        self.logger.info('Adding %s to pool', peer)
        self.connected_nodes[peer.remote] = peer
        peer.add_finished_callback(self._peer_finished)
        for subscriber in self._subscribers:
            subscriber.register_peer(peer)
            peer.add_subscriber(subscriber)
            for msg in msgs:
                subscriber.add_msg(msg)

    async def _run(self) -> None:
        # FIXME: PeerPool should probably no longer be a BaseService, but for now we're keeping
        # it as one in order to ensure we cancel all peers when we terminate.
        if self.has_event_bus:
            self.run_daemon_task(self.maybe_connect_more_peers())
        self.run_daemon_task(self._periodically_report_stats())
        await self.cancel_token.wait()

    async def stop_all_peers(self) -> None:
        self.logger.info("Stopping all peers ...")
        peers = self.connected_nodes.values()
        disconnections = (
            peer.disconnect(DisconnectReason.client_quitting)
            for peer in peers
            if peer.is_running
        )
        await asyncio.gather(*disconnections)

    async def _cleanup(self) -> None:
        await self.stop_all_peers()

    async def connect(self, remote: NodeAPI) -> BasePeer:
        """
        Connect to the given remote and return a Peer instance when successful.
        Raises IneligiblePeer if we are already connected or shaking hands with
        the remote, and lets handshake errors propagate when the remote is
        unreachable, times out or is useless.
""" if self._handshake_locks.is_locked(remote): self.logger.debug2("Skipping %s; already shaking hands", remote) raise IneligiblePeer(f"Already shaking hands with {remote}") async with self._handshake_locks.lock(remote): if remote in self.connected_nodes: self.logger.debug2("Skipping %s; already connected to it", remote) raise IneligiblePeer(f"Already connected to {remote}") try: should_connect = await self.wait( self.connection_tracker.should_connect_to(remote), timeout=1, ) except TimeoutError: self.logger.warning( "ConnectionTracker.should_connect_to request timed out.") raise if not should_connect: raise IneligiblePeer( f"Peer database rejected peer candidate: {remote}") try: self.logger.debug2("Connecting to %s...", remote) peer = await self.wait( self.get_peer_factory().handshake(remote)) return peer except OperationCancelled: # Pass it on to instruct our main loop to stop. raise except BadAckMessage: # This is kept separate from the # `COMMON_PEER_CONNECTION_EXCEPTIONS` to be sure that we aren't # silencing an error in our authentication code. self.logger.error('Got bad auth ack from %r', remote) # dump the full stacktrace in the debug logs self.logger.debug('Got bad auth ack from %r', remote, exc_info=True) raise except MalformedMessage: # This is kept separate from the # `COMMON_PEER_CONNECTION_EXCEPTIONS` to be sure that we aren't # silencing an error in how we decode messages during handshake. self.logger.error( 'Got malformed response from %r during handshake', remote) # dump the full stacktrace in the debug logs self.logger.debug('Got malformed response from %r', remote, exc_info=True) raise except HandshakeFailure as e: self.logger.debug("Could not complete handshake with %r: %s", remote, repr(e)) self.connection_tracker.record_failure(remote, e) raise except COMMON_PEER_CONNECTION_EXCEPTIONS as e: self.logger.debug("Could not complete handshake with %r: %s", remote, repr(e)) raise except Exception: self.logger.exception( "Unexpected error during auth/p2p handshake with %r", remote) raise async def connect_to_nodes(self, nodes: Iterator[NodeAPI]) -> None: # create an generator for the nodes nodes_iter = iter(nodes) while True: if self.is_full or not self.is_operational: return # only attempt to connect to up to the maximum number of available # peer slots that are open. available_peer_slots = self.max_peers - len(self) batch_size = clamp(1, 10, available_peer_slots) batch = tuple(take(batch_size, nodes_iter)) # There are no more *known* nodes to connect to. if not batch: return self.logger.debug( 'Initiating %d peer connection attempts with %d open peer slots', len(batch), available_peer_slots, ) # Try to connect to the peers concurrently. await self.wait( asyncio.gather( *(self.connect_to_node(node) for node in batch), loop=self.get_event_loop(), )) async def connect_to_node(self, node: NodeAPI) -> None: """ Connect to a single node quietly aborting if the peer pool is full or shutting down, or one of the expected peer level exceptions is raised while connecting. """ if self.is_full or not self.is_operational: return try: async with self._connection_attempt_lock: peer = await self.connect(node) except ALLOWED_PEER_CONNECTION_EXCEPTIONS: return # Check again to see if we have *become* full since the previous # check. if self.is_full: self.logger.debug( "Successfully connected to %s but peer pool is full. 
Disconnecting.", peer, ) await peer.disconnect(DisconnectReason.too_many_peers) return elif not self.is_operational: self.logger.debug( "Successfully connected to %s but peer pool no longer operational. Disconnecting.", peer, ) await peer.disconnect(DisconnectReason.client_quitting) return else: await self.start_peer(peer) def _peer_finished(self, peer: AsyncioServiceAPI) -> None: """ Remove the given peer from our list of connected nodes. This is passed as a callback to be called when a peer finishes. """ peer = cast(BasePeer, peer) if peer.remote in self.connected_nodes: self.logger.info("%s finished[%s], removing from pool", peer, peer.disconnect_reason) self.connected_nodes.pop(peer.remote) else: self.logger.warning( "%s finished but was not found in connected_nodes (%s)", peer, tuple(sorted(self.connected_nodes.values())), ) for subscriber in self._subscribers: subscriber.deregister_peer(peer) async def __aiter__(self) -> AsyncIterator[BasePeer]: for peer in tuple(self.connected_nodes.values()): # Yield control to ensure we process any disconnection requests from peers. Otherwise # we could return peers that should have been disconnected already. await asyncio.sleep(0) if peer.is_operational and not peer.is_closing: yield peer async def _periodically_report_stats(self) -> None: while self.is_operational: inbound_peers = len([ peer for peer in self.connected_nodes.values() if peer.inbound ]) self.logger.info("Connected peers: %d inbound, %d outbound", inbound_peers, (len(self.connected_nodes) - inbound_peers)) subscribers = len(self._subscribers) if subscribers: longest_queue = max(self._subscribers, key=operator.attrgetter('queue_size')) self.logger.debug( "Peer subscribers: %d, longest queue: %s(%d)", subscribers, longest_queue.__class__.__name__, longest_queue.queue_size) self.logger.debug("== Peer details == ") for peer in self.connected_nodes.values(): if not peer.is_running: self.logger.warning( "%s is no longer alive but has not been removed from pool", peer) continue self.logger.debug( "%s: uptime=%s, received_msgs=%d", peer, humanize_seconds(peer.uptime), peer.received_msgs_count, ) self.logger.debug("client_version_string='%s'", peer.client_version_string) for line in peer.get_extra_stats(): self.logger.debug(" %s", line) self.logger.debug("== End peer details == ") await self.sleep(self._report_interval)
class Multiplexer(CancellableMixin, MultiplexerAPI):
    logger = get_extended_debug_logger('p2p.multiplexer.Multiplexer')

    _multiplex_token: CancelToken

    _transport: TransportAPI
    _msg_counts: DefaultDict[Type[CommandAPI], int]

    _protocol_locks: ResourceLock
    _protocol_queues: Dict[Type[ProtocolAPI], 'asyncio.Queue[Tuple[CommandAPI, Payload]]']

    def __init__(self,
                 transport: TransportAPI,
                 base_protocol: BaseP2PProtocol,
                 protocols: Sequence[ProtocolAPI],
                 token: CancelToken = None,
                 max_queue_size: int = 4096) -> None:
        if token is None:
            loop = None
        else:
            loop = token.loop
        base_token = CancelToken(f'multiplexer[{transport.remote}]', loop=loop)

        if token is None:
            self.cancel_token = base_token
        else:
            self.cancel_token = base_token.chain(token)

        self._transport = transport
        # the base `p2p` protocol instance.
        self._base_protocol = base_protocol

        # the sub-protocol instances
        self._protocols = protocols

        # Lock to ensure that multiple call sites cannot concurrently stream
        # messages.
        self._multiplex_lock = asyncio.Lock()

        # Lock management on a per-protocol basis to ensure we only have one
        # stream consumer for each protocol.
        self._protocol_locks = ResourceLock()

        # Each protocol gets a queue where messages for the individual protocol
        # are placed when streamed from the transport
        self._protocol_queues = {
            type(protocol): asyncio.Queue(max_queue_size)
            for protocol in self.get_protocols()
        }

        self._msg_counts = collections.defaultdict(int)

    def __str__(self) -> str:
        protocol_infos = ','.join(tuple(
            f"{proto.name}:{proto.version}"
            for proto in self.get_protocols()
        ))
        return f"Multiplexer[{protocol_infos}]"

    def __repr__(self) -> str:
        return f"<{self}>"

    #
    # Transport API
    #
    def get_transport(self) -> TransportAPI:
        return self._transport

    #
    # Message Counts
    #
    def get_total_msg_count(self) -> int:
        return sum(self._msg_counts.values())

    #
    # Proxy Transport methods
    #
    @cached_property
    def remote(self) -> NodeAPI:
        return self._transport.remote

    @cached_property
    def session(self) -> SessionAPI:
        return self._transport.session

    @property
    def is_closing(self) -> bool:
        return self._transport.is_closing

    def close(self) -> None:
        self._transport.close()
        self.cancel_token.trigger()

    #
    # Protocol API
    #
    def has_protocol(self, protocol_identifier: Union[ProtocolAPI, Type[ProtocolAPI]]) -> bool:
        try:
            if isinstance(protocol_identifier, Protocol):
                self.get_protocol_by_type(type(protocol_identifier))
                return True
            elif isinstance(protocol_identifier, type):
                self.get_protocol_by_type(protocol_identifier)
                return True
            else:
                raise TypeError(
                    f"Unsupported protocol value: {protocol_identifier} of type "
                    f"{type(protocol_identifier)}"
                )
        except UnknownProtocol:
            return False

    def get_protocol_by_type(self, protocol_class: Type[TProtocol]) -> TProtocol:
        if issubclass(protocol_class, BaseP2PProtocol):
            return cast(TProtocol, self._base_protocol)
        for protocol in self._protocols:
            if type(protocol) is protocol_class:
                return cast(TProtocol, protocol)
        raise UnknownProtocol(f"No protocol found with type {protocol_class}")

    def get_base_protocol(self) -> BaseP2PProtocol:
        return self._base_protocol

    def get_protocols(self) -> Tuple[ProtocolAPI, ...]:
        return tuple(cons(self._base_protocol, self._protocols))

    def get_protocol_for_command_type(self, command_type: Type[CommandAPI]) -> ProtocolAPI:
        supported_protocols = tuple(
            protocol
            for protocol in self.get_protocols()
            if protocol.supports_command(command_type)
        )

        if len(supported_protocols) == 1:
            return supported_protocols[0]
        elif not supported_protocols:
            raise UnknownProtocol(
                f"Connection does not have any protocols that support the "
                f"request command: {command_type}"
            )
support the " f"request command: {command_type}" ) elif len(supported_protocols) > 1: raise ValidationError( f"Could not determine appropriate protocol for command: " f"{command_type}. Command was found in the " f"protocols {supported_protocols}" ) else: raise Exception("This code path should be unreachable") # # Streaming API # def stream_protocol_messages(self, protocol_identifier: Union[ProtocolAPI, Type[ProtocolAPI]], ) -> AsyncIterator[Tuple[CommandAPI, Payload]]: """ Stream the messages for the specified protocol. """ if isinstance(protocol_identifier, Protocol): protocol_class = type(protocol_identifier) elif isinstance(protocol_identifier, type) and issubclass(protocol_identifier, Protocol): protocol_class = protocol_identifier else: raise TypeError("Unknown protocol identifier: {protocol}") if not self.has_protocol(protocol_class): raise UnknownProtocol(f"Unknown protocol '{protocol_class}'") if self._protocol_locks.is_locked(protocol_class): raise Exception(f"Streaming lock for {protocol_class} is not free.") elif not self._multiplex_lock.locked(): raise Exception("Not multiplexed.") # Mostly a sanity check but this ensures we do better than accidentally # raising an attribute error in whatever race conditions or edge cases # potentially make the `_multiplex_token` unavailable. if not hasattr(self, '_multiplex_token'): raise Exception("No cancel token found for multiplexing.") # We do the wait_iter here so that the call sites in the handshakers # that use this don't need to be aware of cancellation tokens. return self.wait_iter( self._stream_protocol_messages(protocol_class), token=self._multiplex_token, ) async def _stream_protocol_messages(self, protocol_class: Type[Protocol], ) -> AsyncIterator[Tuple[CommandAPI, Payload]]: """ Stream the messages for the specified protocol. """ async with self._protocol_locks.lock(protocol_class): msg_queue = self._protocol_queues[protocol_class] if not hasattr(self, '_multiplex_token'): raise Exception("Multiplexer is not multiplexed") token = self._multiplex_token while not self.is_closing and not token.triggered: try: # We use an optimistic strategy here of using # `get_nowait()` to reduce the number of times we yield to # the event loop. Since this is an async generator it will # yield to the loop each time it returns a value so we # don't have to worry about this blocking other processes. yield msg_queue.get_nowait() except asyncio.QueueEmpty: yield await self.wait(msg_queue.get(), token=token) # # Message reading and streaming API # @asynccontextmanager async def multiplex(self) -> AsyncIterator[None]: """ API for running the background task that feeds individual protocol queues that allows each individual protocol to stream only its own messages. """ # We generate a new token for each time the multiplexer is used to # multiplex so that we can reliably cancel it without requiring the # master token for the multiplexer to be cancelled. async with self._multiplex_lock: multiplex_token = CancelToken( 'multiplex', loop=self.cancel_token.loop, ).chain(self.cancel_token) stop = asyncio.Event() self._multiplex_token = multiplex_token fut = asyncio.ensure_future(self._do_multiplexing(stop, multiplex_token)) # wait for the multiplexing to actually start try: yield finally: # # Prevent corruption of the Transport: # # On exit the `Transport` can be in a few states: # # 1. IDLE: between reads # 2. HEADER: waiting to read the bytes for the message header # 3. BODY: already read the header, waiting for body bytes. 
                #
                # In the IDLE case we get a clean shutdown by simply signaling
                # to `_do_multiplexing` that it should exit which is done with
                # an `asyncio.Event`
                #
                # In the HEADER case we can issue a hard stop either via
                # cancellation or the cancel token. The read *should* be
                # interrupted without consuming any bytes from the
                # `StreamReader`.
                #
                # In the BODY case we want to give the `Transport.recv` call a
                # moment to finish reading the body after which it will be IDLE
                # and will exit via the IDLE exit mechanism.
                stop.set()

                # If the transport is waiting to read the body of the message
                # we want to give it a moment to finish that read. Otherwise
                # this leaves the transport in a corrupt state.
                if self._transport.read_state is TransportState.BODY:
                    try:
                        await asyncio.wait_for(fut, timeout=1)
                    except asyncio.TimeoutError:
                        pass

                # After giving the transport an opportunity to shutdown
                # cleanly, we issue a hard shutdown, first via cancellation and
                # then via the cancel token. This should only end up
                # corrupting the transport in the case where the header data is
                # read but the body data takes too long to arrive which should
                # be very rare and would likely indicate a malicious or broken
                # peer.
                if fut.done():
                    fut.result()
                else:
                    fut.cancel()
                    try:
                        await fut
                    except asyncio.CancelledError:
                        pass
                multiplex_token.trigger()
                del self._multiplex_token

    async def _do_multiplexing(self, stop: asyncio.Event, token: CancelToken) -> None:
        """
        Background task that reads messages from the transport and feeds them
        into individual queues for each of the protocols.
        """
        msg_stream = self.wait_iter(stream_transport_messages(
            self._transport,
            self._base_protocol,
            *self._protocols,
            token=token,
        ), token=token)
        async for protocol, cmd, msg in msg_stream:
            # track total number of messages received for each command type.
            self._msg_counts[type(cmd)] += 1
            queue = self._protocol_queues[type(protocol)]
            try:
                # We must use `put_nowait` here to ensure that in the event
                # that a single protocol queue is full that we don't block
                # other protocol messages getting through.
                queue.put_nowait((cmd, msg))
            except asyncio.QueueFull:
                self.logger.error(
                    "Multiplexing queue for protocol '%s' full. Discarding message: %s",
                    protocol,
                    cmd,
                )
            if stop.is_set():
                break
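
# --- Illustrative sketch of the shutdown sequence used by multiplex() above ---
# The generic shape, with the Trinity-specific pieces (cancel tokens,
# transport read states) stripped out: signal a cooperative stop, give the
# task a bounded grace period to finish any in-flight read, then cancel hard
# and swallow the resulting CancelledError. The helper name is hypothetical.
async def stop_background_task(task: 'asyncio.Task[None]', stop: asyncio.Event) -> None:
    stop.set()  # cooperative signal, checked between reads
    try:
        # shield() keeps wait_for's timeout from cancelling the task itself
        await asyncio.wait_for(asyncio.shield(task), timeout=1)
    except asyncio.TimeoutError:
        pass
    if not task.done():
        task.cancel()  # hard stop
        try:
            await task
        except asyncio.CancelledError:
            pass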