async def shutdown(self) -> None:
    """
    Gracefully shutdown the node manager.

    Disconnects all active nodes and stops service tasks.

    Note:
        This depends on asyncio's Task canceling logic. It waits for all tasks to be cancelled and/or stopped before
        returning.
    """
    self.shutting_down = True

    # shutdown all running tasks for this class
    # to prevent automatic filling of open spots after disconnecting nodes
    logger.debug("stopping tasks...")
    # snapshot: done-callbacks of these tasks may remove them from `self.tasks`
    running = list(self.tasks)
    for t in running:
        t.cancel()
    await asyncio.gather(*running, return_exceptions=True)

    # finally disconnect all existing connections
    # copy the list first, because `disconnect` removes items from self.nodes
    to_disconnect = list(self.nodes)
    logger.debug("disconnecting nodes...")
    disconnect_tasks = [
        asyncio.create_task(n.disconnect(payloads.DisconnectReason.SHUTTING_DOWN))
        for n in to_disconnect
    ]
    await asyncio.gather(*disconnect_tasks, return_exceptions=True)
async def disconnect(self, reason: payloads.DisconnectReason) -> None:
    """
    Close the connection to remote endpoint.

    Only the first call has an effect; later calls return immediately because `disconnecting` is already set.
    Depending on `reason` the address is marked poor or dead, this node's tasks are cancelled and awaited,
    the message router is notified and finally the protocol is asked to disconnect.

    Args:
        reason: reason for disconnecting.
    """
    if self.disconnecting:
        return

    self.disconnecting = True

    logger.debug(f"Disconnect called with reason={reason.name}")
    self.address.disconnect_reason = reason
    if reason in [
        payloads.DisconnectReason.MAX_CONNECTIONS_REACHED,
        payloads.DisconnectReason.POOR_PERFORMANCE,
        payloads.DisconnectReason.HANDSHAKE_VERACK_ERROR,
        payloads.DisconnectReason.HANDSHAKE_VERSION_ERROR,
        payloads.DisconnectReason.UNKNOWN
    ]:
        self.address.set_state_poor()
    elif reason == payloads.DisconnectReason.IPFILTER_NOT_ALLOWED:
        self.address.set_state_dead()

    # Work on a snapshot: tasks may remove themselves from `self.tasks` via done-callbacks while we wait, and
    # mutating a list while iterating over it silently skips elements.
    # Skip the task we are currently running in (in case disconnect() is awaited from one of our own tasks):
    # cancelling it would abort this coroutine before the connection is actually closed.
    current = asyncio.current_task()
    pending = [t for t in self.tasks if t is not current]
    for t in pending:
        t.cancel()
    logger.debug(f"waiting for {len(pending)} task(s) to cancel.")
    # return_exceptions=True: a task that already failed with an error must not abort the disconnect
    await asyncio.gather(*pending, return_exceptions=True)
    logger.debug("done")

    msgrouter.on_node_disconnected(self, reason)
    self.protocol.disconnect()
async def start(self, timeout=5) -> None:
    """
    Start the block syncing service once a running node manager is available.

    Args:
        timeout: number of seconds to wait for `self.nodemgr.is_running` to become true.

    Raises:
        Exception: if the node manager is not running within `timeout` seconds. The sync manager is shut down
            before the exception is raised.
    """

    async def wait_for_nodemanager():
        logger.debug("Waiting for nodemanager to start")
        # poll every 100 ms until the node manager reports it is running
        while True:
            if self.nodemgr.is_running:
                return
            await asyncio.sleep(0.1)

    try:
        await asyncio.wait_for(wait_for_nodemanager(), timeout)
    except asyncio.TimeoutError:
        failure = f"Nodemanager failed to start within specified timeout {timeout}"
        logger.debug(failure)
        await self.shutdown()
        raise Exception(failure)

    logger.debug("Starting services")
    self._service_task = asyncio.create_task(self._run_service())
async def _monitor_node_height(self) -> None:
    """
    Check how fresh each connected node's height information is.

    A node whose `best_height_last_update` timestamp is older than `MAX_HEIGHT_UPDATE_DURATION` seconds is
    disconnected with reason POOR_PERFORMANCE. Every other node is sent a PING that carries our own height,
    asking it to reply with its current height.
    """
    now = datetime.utcnow().timestamp()
    for peer in self.nodes:
        if now - peer.best_height_last_update > self.MAX_HEIGHT_UPDATE_DURATION:
            logger.debug(
                f"Disconnecting node {peer.nodeid} Reason: max height update threshold exceeded."
            )
            asyncio.create_task(
                peer.disconnect(reason=payloads.DisconnectReason.POOR_PERFORMANCE))
            continue

        logger.debug(
            f"Asking node {peer.nodeid_human} to send us a height update (PING)"
        )
        # our own height goes into the PING; without a database we report 0
        own_height = max(0, blockchain.Blockchain().height) if settings.database else 0
        ping = message.Message(msg_type=message.MessageType.PING,
                               payload=payloads.PingPayload(height=own_height))
        ping_task = asyncio.create_task(peer.send_message(ping))
        self.tasks.append(ping_task)
        # drop the task from the bookkeeping list once it finishes
        ping_task.add_done_callback(lambda done: self.tasks.remove(done))
def update_cache_for_block_persist(self, block: payloads.Block) -> None:
    """
    Remove the transactions of a freshly persisted block from the relay cache.

    Transactions that are not in the cache are skipped silently.

    Args:
        block: the block whose transactions should be evicted.
    """
    for transaction in block.transactions:
        try:
            self.cache.pop(transaction.hash())
        except KeyError:
            # not cached, nothing to evict
            continue
        logger.debug(
            f"Found {transaction.hash()} in last persisted block. Removing from relay cache"
        )
async def _process(seed):
    """
    Register one `host:port` seed entry as a known node address.

    If `host` is not already an IP address it is resolved with an 'A' record query via `resolver` (taken from
    the enclosing scope — presumably an aiodns resolver; confirm against the caller). Entries that fail to
    resolve are skipped and not registered.

    Args:
        seed: a `host:port` string.
    """
    host, port = seed.split(':')
    if not is_ip_address(host):
        try:
            result = await resolver.query(host, 'A')
            host = result[0].host
        except aiodns.error.DNSError as e:
            logger.debug(f"Skipping {host}, address could not be resolved: {e}.")
            # previously fell through and registered the unresolved hostname despite the "Skipping" message
            return
    node.NeoNode.addresses.append(payloads.NetworkAddress(address=f"{host}:{port}"))
async def _query_addresses(self) -> None:
    """
    Send an address-list request to every connected node.

    Each request runs as its own task. The task is tracked in `self.tasks` until it completes.
    """
    logger.debug(f"Connected node count {len(self.nodes)}.")
    for peer in self.nodes:
        logger.debug(f"Asking node {peer.nodeid_human} for its address list")
        request = asyncio.create_task(peer.request_address_list())
        self.tasks.append(request)
        request.add_done_callback(lambda finished: self.tasks.remove(finished))
async def connection_lost(self, exc) -> None:
    """
    Event called by the :meth:`base protocol <asyncio.BaseProtocol.connection_lost>`.

    Logs the loss. If the address is still in the CONNECTED state, it disconnects with reason UNKNOWN.
    """
    logger.debug(
        f"{datetime.now()} Connection lost {self.address} exception: {exc}"
    )
    if not self.address.is_state_connected:
        return
    await self.disconnect(payloads.DisconnectReason.UNKNOWN)
async def read_message(self, timeout: Optional[int] = 30) -> Optional[Message]:
    """
    Read and deserialize a single message from the stream.

    Wire layout read here: a 3-byte header whose last byte starts a variable-length payload size
    (0xFD -> uint16, 0xFE -> uint32, 0xFF -> uint64 follow, little endian; any other value is the size itself),
    followed by the payload bytes.

    Args:
        timeout: seconds to wait for a complete message; 0 means wait without a timeout.

    Returns:
        the deserialized message, or None on timeout, cancellation, read error or a message that fails to
        deserialize. On a connection reset or oversized payload `self.client.disconnecting` is set, which
        causes the node's main read loop to stop.
    """
    if timeout == 0:
        # avoid memleak. See: https://bugs.python.org/issue37042
        timeout = None

    async def _read():
        try:
            # readexactly can throw ConnectionResetError
            message_header = await self._stream_reader_orig.readexactly(3)
            payload_length = message_header[2]

            if payload_length == 0xFD:
                len_bytes = await self._stream_reader_orig.readexactly(2)
                payload_length, = struct.unpack("<H", len_bytes)
            elif payload_length == 0xFE:
                len_bytes = await self._stream_reader_orig.readexactly(4)
                payload_length, = struct.unpack("<I", len_bytes)
            elif payload_length == 0xFF:
                # was a duplicated 0xFE check, which made this 8-byte branch unreachable
                len_bytes = await self._stream_reader_orig.readexactly(8)
                payload_length, = struct.unpack("<Q", len_bytes)
            else:
                len_bytes = b''

            if payload_length > Message.PAYLOAD_MAX_SIZE:
                raise ValueError("Invalid format")

            payload_data = await self._stream_reader_orig.readexactly(payload_length)

            raw = message_header + len_bytes + payload_data
            try:
                return Message.deserialize_from_bytes(raw)
            except Exception:
                logger.debug(f"Failed to deserialize message: {traceback.format_exc()}")
                return None

        except (ConnectionResetError, ValueError):
            # ensures we break out of the main run() loop of Node, which triggers a disconnect callback to clean up
            self.client.disconnecting = True
            logger.debug(f"Failed to read message data for reason: {traceback.format_exc()}")
            return None
        except (asyncio.CancelledError, asyncio.IncompleteReadError):
            return None
        except Exception:
            logger.debug(f"error read message 1 {traceback.format_exc()}")
            return None

    try:
        return await asyncio.wait_for(_read(), timeout)
    except (asyncio.TimeoutError, asyncio.CancelledError):
        return None
    except Exception:
        # log instead of printing to stderr, consistent with the inner handler
        logger.debug(f"error read message 2 {traceback.format_exc()}")
        return None
async def _fill_open_connection_spots(self) -> None:
    """
    Try to start connections for every free slot in the node pool.

    Free slots = `max_clients` minus (connected nodes + addresses already queued for connection). For each slot
    an address in state NEW is taken from `node.NeoNode`, queued, and a `connect_to` task is started with
    `_connect_done_cb` as done-callback. When no NEW address is left:
      - if at least `min_clients` nodes are connected, stop;
      - otherwise increase the pool error counter (giving `_query_addresses` time to collect new addresses)
        until it reaches `MAX_NODE_POOL_ERROR`, after which all POOR addresses are reset to NEW and the counter
        is reset to 0.
    """
    open_spots = self.max_clients - (len(self.nodes) + len(self.queued_addresses))

    if open_spots > 0:
        logger.debug(
            f"Found {open_spots} open pool spots, trying to add nodes...")

        # we sort the addresses such that nodes we recently disconnected from are last in the list
        # this matters in case we had to recycle addresses, meaning addresses with state POOR
        # are now labelled NEW again.
        node.NeoNode.addresses.sort(key=lambda addr: addr.last_connected)

        for _ in range(open_spots):
            # now we ask for the first address with the state NEW
            addr = node.NeoNode.get_address_new()
            if addr:
                # an address can be queued and its state not yet changed to CONNECTED, so we must make sure we're
                # not trying to connect to an address that is in an ongoing connection state
                if addr not in self.queued_addresses:
                    logger.debug(f"Adding {addr} to connection queue.")
                    self.queued_addresses.append(addr)
                    # NOTE(review): `_test_client_provider` appears to be a test hook supplying pre-made sockets;
                    # a new generator is created on every call — confirm this is intended.
                    if self._test_client_provider:
                        socket_mock = next(self._test_client_provider())
                        task = asyncio.create_task(
                            node.NeoNode.connect_to(socket=socket_mock))
                    else:
                        task = asyncio.create_task(
                            node.NeoNode.connect_to(addr.ip, addr.port))
                    self.tasks.append(task)
                    task.add_done_callback(self._connect_done_cb)
            else:
                # oh no, we've exhausted our NEW addresses list
                if len(self.nodes) >= self.min_clients:
                    logger.debug(
                        f"No addresses available to fill spots. However, minimum clients still satisfied."
                    )
                    break
                else:
                    # NOTE: despite its constant-style name, MAX_NODE_POOL_ERROR_COUNT is a mutable counter
                    if self.MAX_NODE_POOL_ERROR_COUNT != self.MAX_NODE_POOL_ERROR:
                        # give our `_query_addresses` loop a chance to collect new addresses from connected nodes
                        self.MAX_NODE_POOL_ERROR_COUNT += 1
                        logger.debug(
                            f"Increasing pool spot error count to {self.MAX_NODE_POOL_ERROR_COUNT}."
                        )
                        break
                    else:
                        # we have no other option than to retry any address we know
                        logger.debug("Recycling old addresses.")
                        for addr in node.NeoNode.addresses:
                            if addr.is_state_poor:
                                addr.set_state_new()
                        self.MAX_NODE_POOL_ERROR_COUNT = 0
                        break
def handler_pong(self, msg: message.Message) -> None:
    """
    Handle a PONG message by recording the remote node's reported height.

    Updates `best_height` and stamps `best_height_last_update` with the current UTC timestamp.

    Args:
        msg: a message whose payload is a PingPayload.
    """
    pong = cast(payloads.PingPayload, msg.payload)
    reported = pong.current_height
    logger.debug(f"Updating node {self.nodeid_human} height "
                 f"from {self.best_height} to {reported}")
    self.best_height = reported
    self.best_height_last_update = datetime.utcnow().timestamp()
def increase_node_timeout_count(self, nodeid: int) -> None:
    """
    Bump a node's `nodeweight.timeout_count` by one.

    If the count then exceeds `MAX_NODE_TIMEOUT_COUNT`, the node is disconnected in a background task with
    reason POOR_PERFORMANCE. Unknown node ids are ignored.

    Args:
        nodeid (:attr:`~neo3.network.node.NeoNode.nodeid`): the specific node to update.
    """
    target = self.get_node_by_nodeid(nodeid)
    if not target:
        return

    weight = target.nodeweight
    weight.timeout_count += 1
    if weight.timeout_count <= self.MAX_NODE_TIMEOUT_COUNT:
        return

    logger.debug(f"Disconnecting node {target.nodeid_human} Reason: max timeout count threshold exceeded.")
    asyncio.create_task(target.disconnect(reason=payloads.DisconnectReason.POOR_PERFORMANCE))
async def _process_incoming_data(self) -> None:
    """
    Main loop: read messages and dispatch them to the matching handler until a disconnect starts.

    `read_message(timeout=1)` yields None on timeout, so `self.disconnecting` is re-checked regularly.
    Messages whose type has no entry in `dispatch_table` are logged and dropped.
    """
    logger.debug("Waiting for a message.")
    while True:
        if self.disconnecting:
            break
        # we want to always listen for an incoming message
        incoming = await self.read_message(timeout=1)
        if incoming is None:
            continue

        handler = self.dispatch_table.get(incoming.type, None)
        if not handler:
            logger.debug(f"Unknown message with type: {incoming.type.name}.")
            continue
        handler(incoming)
async def _sync_blocks(self) -> int:
    """
    Request the next consecutive range of blocks from a node reporting the highest known height.

    Returns:
        0 when a request was sent or nothing needed fetching; otherwise a negative code explaining why nothing was
        requested: -1 requests still in flight, -2 block cache full, -3 no connected nodes,
        -4 no node found for the best known height.
    """
    # to simplify syncing, don't ask for more data when there are still requests in flight
    if len(self.block_requests) > 0:
        return -1

    block_cache_space = self.BLOCK_MAX_CACHE_SIZE - len(self.block_cache)
    if block_cache_space <= 0:
        return -2

    try:
        best_node_height = max(map(lambda n: n.best_height, self.nodemgr.nodes))
    except ValueError:
        # if the node list is empty max() fails on an empty list
        return -3

    node = self.nodemgr.get_node_with_height(best_node_height)
    if node is None:
        # no nodes with our desired height. We'll wait for node manager to resolve this
        # or for the nodes to increase their height on the next produced block.
        # (this guard was commented out, which led to an AttributeError on `node.nodeid` below)
        return -4

    best_block_height = self._get_best_stored_block_height()
    block_request_limit = min(block_cache_space, self.BLOCK_NETWORK_REQ_LIMIT)

    to_fetch_ctr = 0
    for i in range(1, block_request_limit + 1):
        next_block_height = best_block_height + i
        if next_block_height > best_node_height:
            break
        self._add_block_flight_info(node.nodeid, next_block_height)
        to_fetch_ctr += 1

    if to_fetch_ctr > 0:
        index_start = best_block_height + 1
        logger.debug(
            f"Asking for blocks {index_start} - {index_start + to_fetch_ctr - 1} from {node.nodeid_human} "
            f"(blocks in cache: {len(self.block_cache)}).")
        await node.request_block_data(index_start=index_start, count=to_fetch_ctr)
    return 0
async def persist_blocks(self) -> None:
    """
    Persist cached blocks in ascending index order until the cache is empty or shutdown starts.

    `_is_persisting_blocks` is True while this runs and always reset to False afterwards. Any unexpected error
    stops processing and is logged, not raised.
    """
    self._is_persisting_blocks = True
    self.block_cache.sort(key=lambda blk: blk.index)
    try:
        while not self.shutting_down:
            if not self.block_cache:
                # cache empty
                break
            next_block = self.block_cache.pop(0)
            await self.ledger.persist(next_block)
            # yield to the event loop between blocks
            await asyncio.sleep(0)
    except Exception:
        logger.debug(
            f"Unexpected exception happened while processing the block cache: {traceback.format_exc()}"
        )
    finally:
        self._is_persisting_blocks = False
def _connect_done_cb(self, future) -> None:
    """
    Done-callback for the connection tasks started by `_fill_open_connection_spots`.

    On failure the address is marked dead and removed from the connection queue. On success the node's message
    handler is started. Both outcomes are reported to `msgrouter`. The task is always removed from `self.tasks`.

    Args:
        future: the finished `connect_to` task, whose result is a `(node, failure)` tuple.
    """
    try:
        if future.cancelled():
            # A task cancelled before its coroutine ever ran (e.g. during shutdown) never reached connect_to's own
            # CancelledError handler, so there is no result to unpack and result() would raise.
            return

        node_instance, failure = future.result()
        # failures here are hard failures from asyncio's loop.create_connection()
        if failure:
            logger.debug(f"Failed to connect to {failure[0]} reason: {failure[1]}.")
            tmp_addr = payloads.NetworkAddress(address=failure[0])
            with suppress(ValueError):
                idx = node.NeoNode.addresses.index(tmp_addr)
                addr = node.NeoNode.addresses[idx]
                addr.set_state_dead()
                self.queued_addresses.remove(tmp_addr)
            msgrouter.on_client_connect_done(None, failure)
        else:
            msgrouter.on_client_connect_done(node_instance, None)
            node_instance.start_message_handler()
    finally:
        self.tasks.remove(future)
async def connection_made(self, transport) -> None:
    """
    Event called by the :meth:`base protocol <asyncio.BaseProtocol.connection_made>`.

    Binds this node to the known address entry for the peer, or otherwise stores the peer's "host:port" on the
    current address. The node is then disconnected if the ip filter does not allow the peer's host.
    """
    peername = self.protocol._stream_writer.get_extra_info('peername')
    peer_host = peername[0]
    host_port = f"{peer_host}:{peername[1]}"

    known_address = self._find_address_by_host_port(host_port)
    if known_address:
        # this scenario occurs when the NodeManager queues seed nodes
        self.address = known_address
    else:
        self.address.address = host_port

    if ipfilter.is_allowed(peer_host):
        return

    logger.debug(f"Blocked by ipfilter: {self.address.address}")
    await self.disconnect(payloads.DisconnectReason.IPFILTER_NOT_ALLOWED)
async def shutdown(self) -> None:
    """
    Gracefully shutdown the sync manager.

    Stops block persisting and all service tasks.

    Note:
        This depends on asyncio's Task canceling logic. It waits for all tasks to be cancelled and/or stopped before
        returning.
    """
    logger.debug("Syncmanager shutting down")
    self.shutting_down = True
    self.block_cache = []
    logger.debug("Stopping tasks...")
    # include the main service task, but only once so repeated shutdown() calls don't keep growing the list
    if self._service_task and self._service_task not in self._tasks:
        self._tasks.append(self._service_task)
    for t in self._tasks:
        t.cancel()
    await asyncio.gather(*self._tasks, return_exceptions=True)
def _payload_from_data(msg_type, data):
    """
    Deserialize `data` into the payload class that belongs to `msg_type`.

    Returns:
        the payload object, or None (after logging) when `msg_type` has no known payload class.
    """
    with serialization.BinaryReader(data) as reader:
        payload_classes = {
            MessageType.INV: payloads.InventoryPayload,
            MessageType.GETDATA: payloads.InventoryPayload,
            MessageType.GETBLOCKBYINDEX: payloads.GetBlockByIndexPayload,
            MessageType.VERSION: payloads.VersionPayload,
            MessageType.VERACK: payloads.EmptyPayload,
            MessageType.BLOCK: payloads.Block,
            MessageType.HEADERS: payloads.HeadersPayload,
            MessageType.PING: payloads.PingPayload,
            MessageType.PONG: payloads.PingPayload,
            MessageType.ADDR: payloads.AddrPayload,
            MessageType.TRANSACTION: payloads.Transaction,
        }
        payload_cls = payload_classes.get(msg_type)
        if payload_cls is None:
            logger.debug(f"Unsupported payload {msg_type.name}")
            return None
        return reader.read_serializable(payload_cls)
async def _do_handshake(
        self) -> Tuple[bool, Optional[payloads.DisconnectReason]]:
    """
    Perform the VERSION / VERACK exchange with the remote node.

    Returns:
        (True, None) on success. Otherwise (False, reason), after the node has been disconnected with that reason.
    """

    async def _abort(reason):
        # disconnect and report the failure reason to the caller
        await self.disconnect(reason)
        return False, reason

    caps: List[capabilities.NodeCapability] = [capabilities.FullNodeCapability(0)]
    # TODO: fix nonce and port if a service is running
    version_payload = payloads.VersionPayload.create(nonce=123, user_agent="NEO3-PYTHON", capabilities=caps)
    await self.send_message(message.Message(msg_type=message.MessageType.VERSION, payload=version_payload))

    reply = await self.read_message(timeout=3)
    if not reply or reply.type != message.MessageType.VERSION or not self._validate_version(reply.payload):
        return await _abort(payloads.DisconnectReason.HANDSHAKE_VERSION_ERROR)

    await self.send_message(message.Message(msg_type=message.MessageType.VERACK))

    reply = await self.read_message(timeout=3)
    if not reply or reply.type != message.MessageType.VERACK:
        return await _abort(payloads.DisconnectReason.HANDSHAKE_VERACK_ERROR)

    logger.debug(
        f"Connected to {self.version.user_agent} @ {self.address.address}: {self.best_height}."
    )
    msgrouter.on_node_connected(self)
    return True, None
def handler_inv(self, msg: message.Message) -> None:
    """
    Handle an INV message.

    Only BLOCK inventories with at least one hash are acted on: they trigger a PING carrying our own height.
    Other inventory types are logged and ignored.

    Args:
        msg: a message whose payload is an InventoryPayload.
    """
    inventory = cast(payloads.InventoryPayload, msg.payload)
    if inventory.type != payloads.InventoryType.BLOCK:
        logger.debug(
            f"Message with type INV received. No processing for payload type "  # type:ignore
            f"{inventory.type.name} implemented")
        return

    # neo-cli broadcasts INV messages on a regular interval. We can use those as trigger to request
    # their latest block height
    if len(inventory.hashes) == 0:
        return

    own_height = max(0, blockchain.Blockchain().height) if settings.database else 0
    ping = message.Message(msg_type=message.MessageType.PING,
                           payload=payloads.PingPayload(height=own_height))
    self._create_task_with_cleanup(self.send_message(ping))
async def _read():
    """
    Read one message from `self._stream_reader_orig` (`self` comes from the enclosing method's scope).

    The third header byte starts a variable-length payload size: 0xFD, 0xFE and 0xFF are followed by a
    little-endian uint16, uint32 and uint64 respectively. Any other value is the size itself.

    Returns:
        the deserialized Message, or None on any failure. On a connection reset or oversized payload
        `self.client.disconnecting` is set so the node's main loop stops.
    """
    try:
        # readexactly can throw ConnectionResetError
        message_header = await self._stream_reader_orig.readexactly(3)
        payload_length = message_header[2]

        if payload_length == 0xFD:
            len_bytes = await self._stream_reader_orig.readexactly(2)
            payload_length, = struct.unpack("<H", len_bytes)
        elif payload_length == 0xFE:
            len_bytes = await self._stream_reader_orig.readexactly(4)
            payload_length, = struct.unpack("<I", len_bytes)
        elif payload_length == 0xFF:
            # was a duplicated 0xFE check, which made this 8-byte branch unreachable
            len_bytes = await self._stream_reader_orig.readexactly(8)
            payload_length, = struct.unpack("<Q", len_bytes)
        else:
            len_bytes = b''

        if payload_length > Message.PAYLOAD_MAX_SIZE:
            raise ValueError("Invalid format")

        payload_data = await self._stream_reader_orig.readexactly(payload_length)
        raw = message_header + len_bytes + payload_data

        with serialization.BinaryReader(raw) as br:
            m = Message()
            try:
                m.deserialize(br)
                return m
            except Exception:
                logger.debug(
                    f"Failed to deserialize message: {traceback.format_exc()}"
                )
                return None

    except (ConnectionResetError, ValueError):
        # ensures we break out of the main run() loop of Node, which triggers a disconnect callback to clean up
        self.client.disconnecting = True
        logger.debug(
            f"Failed to read message data for reason: {traceback.format_exc()}"
        )
        return None
    except (asyncio.CancelledError, asyncio.IncompleteReadError):
        return None
    except Exception:
        logger.debug(f"error read message 1 {traceback.format_exc()}")
        return None
def _validate_version(self, version) -> bool:
    """
    Check a received VERSION payload and record what it tells us about the peer.

    The payload is rejected when its nonce equals our own `nodeid` or when its magic does not match
    `settings.network.magic`. If the peer advertises a ServerCapability, its address entry is updated or created
    in the CONNECTED state. The peer must advertise a FullNodeCapability. Its `start_height` becomes
    `best_height` and the payload is stored in `self.version`.

    Returns:
        True if the version is acceptable and contains a FullNodeCapability, False otherwise.
    """
    if version.nonce == self.nodeid:
        logger.debug("Client is self.")
        return False
    if version.magic != settings.network.magic:
        logger.debug(f"Wrong network id {version.magic}.")
        return False

    if any(isinstance(cap, capabilities.ServerCapability) for cap in version.capabilities):
        existing = self._find_address_by_host_port(self.address.address)
        if existing:
            existing.set_state_connected()
            existing.capabilities = version.capabilities
        else:
            logger.debug(
                f"Adding address from outside {self.address.address}.")
            # new connection initiated from outside
            self.addresses.append(
                payloads.address.NetworkAddress(
                    address=self.address.address,
                    capabilities=version.capabilities,
                    state=payloads.address.AddressState.CONNECTED))

    full_node = next(
        (cap for cap in version.capabilities if isinstance(cap, capabilities.FullNodeCapability)), None)
    if full_node is None:
        return False

    # update nodes height indicator
    self.best_height = full_node.start_height
    self.best_height_last_update = datetime.utcnow().timestamp()
    self.version = version
    return True
async def wait_for_nodemanager():
    """
    Poll every 0.1 seconds until `self.nodemgr.is_running` is true.

    NOTE(review): `self` is a free variable here, presumably captured from an enclosing method (as in `start`).
    """
    logger.debug("Waiting for nodemanager to start")
    while True:
        if self.nodemgr.is_running:
            break
        await asyncio.sleep(0.1)
async def connect_to(
        host: str = None,
        port: int = None,
        timeout=3,
        loop=None,
        socket=None
) -> Tuple[Optional[NeoNode], Optional[Tuple[str, str]]]:
    """
    Establish a connection to a Neo node

    Note: performs the initial connection handshake and validation.

    Args:
        host: remote address in IPv4 format
        port: remote port
        timeout: maximum time establishing a connection may take
        loop: custom loop
        socket: an already created socket to connect with instead of `host`/`port`

    Raises:
        ValueError: if host/port and the socket argument are specified at the same time, or if neither is specified.

    Returns:
        Tuple:
            - (Node instance, None) - if a connection was successfully established
            - (None, (ip address, error reason)) - if a connection failed to establish. Reasons include connection
              timeout, connection full and handshake errors. # noqa
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    if host is not None or port is not None:
        if socket is not None:
            raise ValueError(
                'host/port and socket can not be specified at the same time'
            )
    if socket is None and (host is None or port is None):
        raise ValueError(
            'host and port was not specified and no sock specified')

    transport = None
    try:
        if socket:
            logger.debug(f"Trying to connect to socket: {socket}.")
            connect_coro = loop.create_connection(protocol.NeoProtocol, sock=socket)
        else:
            logger.debug(f"Trying to connect to: {host}:{port}.")
            connect_coro = loop.create_connection(protocol.NeoProtocol, host, port, family=IP4_FAMILY)

        transport, node = await asyncio.wait_for(connect_coro, timeout)

        success, fail_reason = await node.client._do_handshake()
        if success:
            return node.client, None
        else:
            raise Exception(fail_reason)
    except asyncio.TimeoutError:
        reason = "Timed out"
    except OSError as e:
        reason = f"Failed to connect for reason {e}"
    except asyncio.CancelledError:
        reason = "Cancelled"
    except Exception:
        reason = traceback.format_exc()

    # The TCP connection may have been established before the failure (handshake error, cancellation or an
    # unexpected exception); close it so the socket is not leaked. close() on an already closing transport is a no-op.
    if transport is not None:
        transport.close()
    return None, (f"{host}:{port}", reason)
async def _check_timeout(self) -> int:
    """
    This function checks if any of the outstanding data requests have exceeded the response time threshold. If so
    then the violating node is tagged. Next, a new node is selected to request the data we have not yet received in
    the hope that this node does perform adequately.

    Returns:
        -1 if there are no outstanding requests, -2 if no request timed out, -3 if no connected node can serve the
        re-request, 0 otherwise.
    """
    if len(self.block_requests) == 0:
        # no outstanding data requests
        return -1

    timedout_flights = dict()
    now = datetime.utcnow().timestamp()
    # find outstanding requests that timed out
    for height, request_info in self.block_requests.items():
        flight_info = request_info.most_recent_flight()
        if flight_info and now - flight_info.start_time > self.BLOCK_REQUEST_TIMEOUT:
            timedout_flights[height] = flight_info

    if len(timedout_flights) == 0:
        # no timeouts, every request is still nicely within the threshold
        return -2

    remaining_requests = []
    nodes_to_tag_for_timeout = set()
    best_stored_block_height = self._get_best_stored_block_height()

    for height, flight_info in timedout_flights.items():
        # adding to set to ensure we only tag nodes once per request
        nodes_to_tag_for_timeout.add(flight_info.node_id)

        try:
            request_info = self.block_requests[height]
        except KeyError:
            # TODO: check if still possible. After refactor should not be reachable anymore
            continue

        if flight_info.height <= best_stored_block_height:
            # already stored in the meantime; nothing left to re-request for this height
            with suppress(KeyError):
                self.block_requests.pop(height)
            continue

        # tag the node for not delivering data within the set threshold
        request_info.mark_failed_node(flight_info.node_id)
        remaining_requests.append(request_info)

    for node_id in nodes_to_tag_for_timeout:
        # affect node weighting by increasing node timeout count
        self.nodemgr.increase_node_timeout_count(node_id)

    if len(remaining_requests) > 0:
        # order by height so the first/last entries really are the lowest/highest outstanding blocks
        remaining_requests.sort(key=lambda ri: ri.height)
        request_info_first = remaining_requests[0]
        request_info_last = remaining_requests[-1]
        # using the last request_info to find a suitable node, because the last request info is the
        # highest block to look for
        node = self.nodemgr.get_least_failed_node(request_info_last)
        if node is None:
            # no connected nodes that can satisfy our request.
            # Return and let the node manager first resolve finding nodes
            return -3

        # it is only possible to request block data by height (using the GetBlockData payload) for a consecutive
        # range. One option is to find these ranges and send a request for each range. Another option, which keeps
        # the code much simpler, is to just request the full range (from start to end height) and ignore any gaps
        # in the range that have been filled in the mean time by other nodes that timed out.
        # This leads to minimal (acceptable) additional traffic in certain scenarios.
        for request_info in remaining_requests:
            request_info.add_new_flight(convenience.FlightInfo(node.nodeid, request_info.height))

        # the range is inclusive on both ends; `max(1, last - first)` used to drop the highest block
        count = request_info_last.height - request_info_first.height + 1
        logger.debug(
            f"Block timeout for blocks {request_info_first.height} - {request_info_last.height}. "
            f"Trying again using next available node {node.nodeid_human}. "
            f"start={request_info_first.height}, count={count}.")
        await node.request_block_data(index_start=request_info_first.height, count=count)
        node.nodeweight.append_new_request_time()

    return 0