Ejemplo n.º 1
0
def get_logger(name: str) -> eth_utils.ExtendedDebugLogger:
    """
    Return the extended debug logger called ``name``, creating every ancestor logger first.

    ``eth_utils.get_extended_debug_logger()`` is called once for each proper dotted prefix
    of ``name``, shortest first, and finally for ``name`` itself.  For ``'foo.bar.baz'``
    that is ``'foo'``, then ``'foo.bar'``, then ``'foo.bar.baz'``.

    Why: if a callsite created ``'foo.bar.baz'`` before all of its ancestors existed, it
    would end up with the root logger as its parent and would therefore not inherit the
    properties configured on ``'foo.bar'``.
    """
    # TODO: Move this to eth_utils once we're confident it has no unintended side-effects.
    parts = name.split('.')
    ancestor_names = ('.'.join(parts[:depth]) for depth in range(1, len(parts)))
    for ancestor_name in ancestor_names:
        eth_utils.get_extended_debug_logger(ancestor_name)
    return eth_utils.get_extended_debug_logger(name)
Ejemplo n.º 2
0
 def __init__(self, identity_scheme_registry: IdentitySchemeRegistry, db: DatabaseAPI) -> None:
     """
     Keep a reference to the backing database and the identity scheme registry, and
     create a logger named ``<module>.<ClassName>`` for this instance.
     """
     self.db = db
     logger_name = f"{self.__module__}.{self.__class__.__name__}"
     self.logger = get_extended_debug_logger(logger_name)
     self._identity_scheme_registry = identity_scheme_registry
Ejemplo n.º 3
0
class ETHProtocolV65(BaseETHProtocol):
    """
    Version 65 of the ``eth`` protocol.

    Its command set includes the pooled-transaction messages
    (``NewPooledTransactionHashes``, ``GetPooledTransactions``, ``PooledTransactions``).
    """
    version = 65
    command_length = 20
    status_command_type = Status

    # NOTE(review): this logger name carries no ``V65`` suffix; confirm that using the
    # un-suffixed ``ETHProtocol`` name is intentional.
    logger = get_extended_debug_logger('trinity.protocol.eth.proto.ETHProtocol')

    # Order matters here (presumably it maps to command ids) -- do not reorder.
    commands = (
        Status,
        NewBlockHashes,
        Transactions,
        GetBlockHeaders,
        BlockHeaders,
        GetBlockBodies,
        BlockBodies,
        NewBlock,
        NewPooledTransactionHashes,
        GetPooledTransactions,
        PooledTransactions,
        GetNodeData,
        NodeData,
        GetReceipts,
        Receipts,
    )
Ejemplo n.º 4
0
    def __init__(
        self,
        inbound_envelope_receive_channel: trio.abc.ReceiveChannel[InboundEnvelope],
        inbound_message_receive_channel: trio.abc.ReceiveChannel[AnyInboundMessage],
        pool: PoolAPI,
        enr_db: ENRDatabaseAPI,
        events: EventsAPI = None,
    ) -> None:
        """
        Store the dispatcher's collaborators and set up its outbound channel and bookkeeping.

        :param inbound_envelope_receive_channel: receive channel of inbound envelopes
        :param inbound_message_receive_channel: receive channel of inbound messages
        :param pool: session pool
        :param enr_db: ENR database
        :param events: event hooks; a fresh ``Events()`` is created when omitted
        """
        self.logger = get_extended_debug_logger("ddht.Dispatcher")

        self._pool = pool
        self._enr_db = enr_db
        self._events = events if events is not None else Events()

        self._inbound_envelope_receive_channel = inbound_envelope_receive_channel
        self._inbound_message_receive_channel = inbound_message_receive_channel

        # Memory channel with a buffer of 256 outbound messages; both ends are kept.
        send_end, receive_end = trio.open_memory_channel[AnyOutboundMessage](256)
        self._outbound_message_send_channel = send_end
        self._outbound_message_receive_channel = receive_end

        self.subscription_manager = SubscriptionManager()

        # Request-id bookkeeping; both sets start empty.
        self._reserved_request_ids = set()
        self._active_request_ids = set()
Ejemplo n.º 5
0
class NoopDiscoveryService(Service):
    """
    Stub discovery service.

    It answers every ``PeerCandidatesRequest`` and ``RandomBootnodeRequest`` seen on
    the event bus with an empty result.
    """
    logger = get_extended_debug_logger('p2p.discovery.NoopDiscoveryService')

    def __init__(self, event_bus: EndpointAPI) -> None:
        self._event_bus = event_bus

    async def _reply_with_no_results(self, request) -> None:
        # Build the response type the request expects, wrapping an empty tuple.
        empty_response = request.expected_response_type()(tuple())
        await self._event_bus.broadcast(empty_response, request.broadcast_config())

    async def handle_get_peer_candidates_requests(self) -> None:
        async for request in self._event_bus.stream(PeerCandidatesRequest):
            self.logger.debug("Servicing request for more peer candidates")
            await self._reply_with_no_results(request)

    async def handle_get_random_bootnode_requests(self) -> None:
        async for request in self._event_bus.stream(RandomBootnodeRequest):
            self.logger.debug("Servicing request for boot nodes")
            await self._reply_with_no_results(request)

    async def run(self) -> None:
        for handler in (self.handle_get_peer_candidates_requests,
                        self.handle_get_random_bootnode_requests):
            self.manager.run_daemon_task(handler)

        await self.manager.wait_finished()
Ejemplo n.º 6
0
class Handshaker(HandshakerAPI):
    """
    Base class for performing the handshake of one particular protocol.

    Its main reason to exist is to be a home for the parameters that the protocol
    handshake needs.
    """

    logger = get_extended_debug_logger('p2p.handshake.Handshaker')
Ejemplo n.º 7
0
    def __init__(self,
                 network: NetworkProtocol,
                 target: NodeID,
                 concurrency: int = 3) -> None:
        """
        Set up the explorer's state for exploring towards ``target`` on ``network``.

        :param network: network used for exploration; its ``local_node_id`` is read here
        :param target: node id being explored towards
        :param concurrency: stored as ``_concurrency`` (default 3)
        """
        self.logger = get_extended_debug_logger("ddht.Explorer")

        self._network = network
        self.target = target
        self._concurrency = concurrency

        self._condition = trio.Condition()

        # Per-node bookkeeping sets.  Our own node id starts out in ``queried``.
        self.in_flight = set()
        self.seen = set()
        self.queried = {self._network.local_node_id}
        self.unresponsive, self.unreachable, self.invalid = set(), set(), set()

        # A small buffer (16) applies back-pressure against the workers: if the
        # consumer only takes a few nodes, there is no need to keep issuing requests.
        send_end, receive_end = trio.open_memory_channel[ENRAPI](16)
        self._send_channel = send_end
        self._receive_channel = receive_end

        # Set once the initial set of nodes for exploration has been seeded.
        self._exploration_seeded = trio.Event()

        # Set once the service is running and nodes can be streamed.
        self._ready = trio.Event()
Ejemplo n.º 8
0
    def __init__(
        self,
        outbound_envelope_receive_channel: ReceiveChannel[OutboundEnvelope],
        outbound_datagram_send_channel: SendChannel[OutboundDatagram],
    ) -> None:
        """
        :param outbound_envelope_receive_channel: channel of outbound envelopes (stored)
        :param outbound_datagram_send_channel: channel for outbound datagrams (stored)
        """
        self.logger = get_extended_debug_logger("ddht.EnvelopeEncoder")

        self._outbound_datagram_send_channel = outbound_datagram_send_channel
        self._outbound_envelope_receive_channel = outbound_envelope_receive_channel
Ejemplo n.º 9
0
class MsgBuffer(PeerSubscriber):
    """
    Peer subscriber that subscribes to ``BaseCommand`` messages and lets callers drain
    whatever has been queued with :meth:`get_messages`.
    """
    logger = get_extended_debug_logger('p2p.peer.MsgBuffer')
    # Presumably the capacity of ``msg_queue`` -- confirm in PeerSubscriber.
    msg_queue_maxsize = 500
    subscription_msg_types = frozenset({BaseCommand})

    @to_tuple
    def get_messages(self) -> Iterator[PeerMessage]:
        """Take every message currently queued, without waiting; ``to_tuple`` collects them."""
        queue = self.msg_queue
        while not queue.empty():
            yield queue.get_nowait()
Ejemplo n.º 10
0
def setup_queue_logging(log_queue: 'Queue[str]', level: int) -> None:
    """
    Send log records to ``log_queue`` through a ``QueueHandler``.

    The handler is attached to the logger returned for the empty name (presumably the
    root logger).  Both the handler and that logger are set to ``level``.  A debug
    record with this process's PID is emitted once setup is done.
    """
    target_logger = get_extended_debug_logger('')

    handler = QueueHandler(log_queue)
    handler.setLevel(level)

    target_logger.addHandler(handler)
    target_logger.setLevel(level)
    target_logger.debug('Logging initialized: PID=%s', os.getpid())
Ejemplo n.º 11
0
    class PausingVM(original_vm_class):  # type: ignore
        """
        Subclass of ``original_vm_class`` that always uses ``PausingVMState`` as its
        state class and reports that state's ``stats_counter`` as beam statistics.
        """
        logger = get_extended_debug_logger(
            'eth.vm.base.VM.' + original_vm_class.__name__
        )

        @classmethod
        def get_state_class(cls) -> Type[StateAPI]:
            # Fixed regardless of what the wrapped VM class would return.
            return PausingVMState

        def get_beam_stats(self) -> BeamStats:
            state = self.state
            return state.stats_counter
Ejemplo n.º 12
0
class BlockImportServer(Service):
    """
    Service that serves ``DoStatelessBlockImport`` requests from the event bus by
    importing the requested block into the given beam chain.
    """
    logger = get_extended_debug_logger('trinity.sync.beam.BlockImportServer')

    def __init__(
            self,
            event_bus: EndpointAPI,
            beam_chain: BeamChain) -> None:
        self._event_bus = event_bus
        self._beam_chain = beam_chain

    async def run(self) -> None:
        self.manager.run_daemon_task(self.serve, self._event_bus, self._beam_chain)
        await self.manager.wait_finished()

    async def serve(
            self,
            event_bus: EndpointAPI,
            beam_chain: BeamChain) -> None:
        """
        Import the block carried by each ``DoStatelessBlockImport`` event on ``event_bus``.

        The import runs in the event loop's default executor.  On success, and only while
        this service is still running, completion is announced via
        ``_broadcast_import_complete`` (the reply is ``StatelessBlockImportDone``).  A
        ``StateUnretrievable`` failure is logged and the next request is awaited.  Once the
        service stops running, the loop exits after the current request.
        """
        loop = asyncio.get_event_loop()
        async for request in event_bus.stream(DoStatelessBlockImport):
            block = request.block
            # Run the import in another thread so it does not block the event loop.
            # Maybe build the pausing chain inside the new process?
            import_completion = loop.run_in_executor(
                None,  # default executor
                partial_import_block(beam_chain, block),
            )

            # Shielded so that cancelling this task does not cancel the import future.
            # Original intent: hold the service from shutting down until the import is done;
            # in the tests the service is await-cancelled so the in-progress block is known to
            # be complete, and no completion is broadcast (so the server is not re-triggered).
            # NOTE(review): per the asyncio docs, the awaiting task still receives
            # CancelledError immediately; confirm the shutdown-wait intent actually holds.
            try:
                await asyncio.shield(import_completion)
            except StateUnretrievable as exc:
                self.logger.debug(
                    "Not broadcasting about %s Beam import. Listening for next request, because %r",
                    block,
                    exc
                )
            else:
                if self.manager.is_running:
                    _broadcast_import_complete(
                        event_bus,
                        block,
                        request.broadcast_config(),
                        import_completion,  # type: ignore
                    )

            if not self.manager.is_running:
                break
Ejemplo n.º 13
0
    def __init__(self, network: NetworkAPI) -> None:
        """
        Create the client on top of ``network`` and register it there as a talk protocol.
        """
        self.logger = get_extended_debug_logger("ddht.AlexandriaClient")

        self.network = network
        self.request_tracker = RequestTracker()
        self.subscription_manager = SubscriptionManager()

        # NOTE(review): registration happens before ``_active_request_ids`` is assigned;
        # if ``add_talk_protocol`` calls back into this client, that attribute is missing.
        self.network.add_talk_protocol(self)

        self._active_request_ids = set()
Ejemplo n.º 14
0
    def __init__(
        self,
        inbound_datagram_receive_channel: ReceiveChannel[InboundDatagram],
        inbound_envelope_send_channel: SendChannel[InboundEnvelope],
        local_node_id: NodeID,
    ) -> None:
        """
        :param inbound_datagram_receive_channel: channel of inbound datagrams (stored)
        :param inbound_envelope_send_channel: channel for inbound envelopes (stored)
        :param local_node_id: id of the local node (stored)
        """
        self.logger = get_extended_debug_logger("ddht.EnvelopeDecoder")

        self._local_node_id = local_node_id
        self._inbound_datagram_receive_channel = inbound_datagram_receive_channel
        self._inbound_envelope_send_channel = inbound_envelope_send_channel
Ejemplo n.º 15
0
async def DatagramSender(
    manager: ManagerAPI,
    outbound_datagram_receive_channel: ReceiveChannel[OutboundDatagram],
    sock: SocketType,
) -> None:
    """
    Send each datagram arriving on ``outbound_datagram_receive_channel`` through ``sock``
    to the endpoint it names.

    Iteration ends when the channel has no more items.  The channel is used as an async
    context manager (for trio channels, this closes it on exit).  ``manager`` is not
    used in this body.
    """
    logger = get_extended_debug_logger("ddht.DatagramSender")

    async with outbound_datagram_receive_channel:
        async for outbound in outbound_datagram_receive_channel:
            payload, destination = outbound
            await send_datagram(sock, payload, destination)
            # Logged after the send has completed.
            logger.debug2("Sending %d bytes to %s", len(payload), destination)
Ejemplo n.º 16
0
    def __init__(self, center_node_id: NodeID, bucket_size: int) -> None:
        """
        Create ``NUM_ROUTING_TABLE_BUCKETS`` empty buckets around ``center_node_id``.

        :param center_node_id: node id the table is organised around
        :param bucket_size: maximum length of each bucket deque
        """
        self.logger = get_extended_debug_logger("ddht.KademliaRoutingTable")
        self.center_node_id = center_node_id
        self.bucket_size = bucket_size

        bucket_indices = range(NUM_ROUTING_TABLE_BUCKETS)
        # Bounded deques: appending to a full one discards an item from the other end.
        self.buckets: Tuple[Deque[NodeID], ...] = tuple(
            collections.deque(maxlen=bucket_size) for _ in bucket_indices
        )
        # Replacement caches have no length limit.
        self.replacement_caches: Tuple[Deque[NodeID], ...] = tuple(
            collections.deque() for _ in bucket_indices
        )

        self.bucket_update_order: Deque[int] = collections.deque()
Ejemplo n.º 17
0
    def __init__(self, client: ClientAPI, bootnodes: Collection[ENRAPI]) -> None:
        """
        :param client: client whose local ENR node id centres the routing table
        :param bootnodes: bootnode ENRs; copied into a tuple
        """
        self.logger = get_extended_debug_logger("ddht.Network")

        self.client = client
        self._bootnodes = tuple(bootnodes)

        local_node_id = self.client.enr_manager.enr.node_id
        self.routing_table = KademliaRoutingTable(local_node_id, ROUTING_TABLE_BUCKET_SIZE)
        self._routing_table_ready = trio.Event()
        # LRU map with capacity 2048 (presumably node -> time of last pong).
        self._last_pong_at = LRU(2048)

        self._talk_protocols = {}

        self._ping_handler_ready = trio.Event()
        self._find_nodes_handler_ready = trio.Event()
Ejemplo n.º 18
0
    def __init__(
        self,
        network: AlexandriaNetworkAPI,
        content_key: ContentKey,
        concurrency: int = 3,
    ) -> None:
        """
        :param network: network used to look for the content
        :param content_key: key of the wanted content; its content id is derived here
        :param concurrency: stored as ``concurrency`` (default 3)
        """
        self.logger = get_extended_debug_logger("ddht.Seeker")

        self.concurrency = concurrency

        self.content_key = content_key
        self.content_id = content_key_to_content_id(content_key)

        self._network = network

        # Zero-capacity channel: a send waits until a receiver takes the item.
        send_end, receive_end = trio.open_memory_channel[bytes](0)
        self._content_send = send_end
        self.content_receive = receive_end
Ejemplo n.º 19
0
class ConnectionTrackerServer(Service):
    """
    Server to handle the event bus communication for ``BlacklistEvent`` and
    ``GetBlacklistedPeersRequest``/``GetBlacklistedPeersResponse`` events.

    * ``BlacklistEvent``: the blacklisting is recorded in ``tracker``.
    * ``GetBlacklistedPeersRequest``: answered with a ``GetBlacklistedPeersResponse``
      carrying ``tracker.get_blacklisted()``.
    """
    logger = get_extended_debug_logger(
        'trinity.components.network_db.ConnectionTrackerServer')

    def __init__(self, event_bus: EndpointAPI,
                 tracker: BaseConnectionTracker) -> None:
        self.tracker = tracker
        self.event_bus = event_bus

    async def handle_get_blacklisted_requests(self) -> None:
        """Reply to every ``GetBlacklistedPeersRequest`` with the blacklisted peers."""
        async for req in self.event_bus.stream(GetBlacklistedPeersRequest):
            self.logger.debug2('Received get_blacklisted request')
            blacklisted = await self.tracker.get_blacklisted()
            await self.event_bus.broadcast(
                GetBlacklistedPeersResponse(blacklisted),
                req.broadcast_config())

    async def handle_blacklist_command(self) -> None:
        """Record every ``BlacklistEvent`` (remote, timeout, reason) in the tracker."""
        async for command in self.event_bus.stream(BlacklistEvent):
            self.logger.debug2(
                'Received blacklist command: remote: %s | timeout: %s | reason: %s',
                command.remote,
                humanize_seconds(command.timeout_seconds),
                command.reason,
            )
            self.tracker.record_blacklist(command.remote,
                                          command.timeout_seconds,
                                          command.reason)

    async def run(self) -> None:
        self.logger.debug("Running ConnectionTrackerServer")

        self.manager.run_daemon_task(
            self.handle_blacklist_command,
            name='ConnectionTrackerServer.handle_blacklist_command',
        )
        self.manager.run_daemon_task(
            self.handle_get_blacklisted_requests,
            name='ConnectionTrackerServer.handle_get_blacklisted_requests',
        )

        await self.manager.wait_finished()
Ejemplo n.º 20
0
class StaticDiscoveryService(Service):
    """
    Stand-in discovery service whose only candidates are a fixed set of static nodes.

    Peer-candidate requests receive up to ``max_candidates`` of those nodes, and
    random-bootnode requests receive at most one of them.
    """
    _static_peers: Tuple[NodeAPI, ...]
    _event_bus: EndpointAPI

    logger = get_extended_debug_logger('p2p.discovery.StaticDiscoveryService')

    def __init__(self, event_bus: EndpointAPI,
                 static_peers: Sequence[NodeAPI]) -> None:
        self._event_bus = event_bus
        self._static_peers = tuple(static_peers)

    async def handle_get_peer_candidates_requests(self) -> None:
        async for request in self._event_bus.stream(PeerCandidatesRequest):
            await self._broadcast_nodes(request, self._select_nodes(request.max_candidates))

    async def handle_get_random_bootnode_requests(self) -> None:
        async for request in self._event_bus.stream(RandomBootnodeRequest):
            await self._broadcast_nodes(request, self._select_nodes(1))

    def _select_nodes(self, max_nodes: int) -> Tuple[NodeAPI, ...]:
        """Return all static peers if they fit in ``max_nodes``, else a random sample of that size."""
        peers = self._static_peers
        if max_nodes < len(peers):
            chosen = tuple(random.sample(peers, max_nodes))
            self.logger.debug2("Replying with subset of static nodes: %r", chosen)
            return chosen
        self.logger.debug2("Replying with all static nodes: %r", peers)
        return peers

    async def _broadcast_nodes(
            self, event: BaseRequestResponseEvent[PeerCandidatesResponse],
            nodes: Sequence[NodeAPI]) -> None:
        response = event.expected_response_type()(tuple(nodes))
        await self._event_bus.broadcast(response, event.broadcast_config())

    async def run(self) -> None:
        for handler in (self.handle_get_peer_candidates_requests,
                        self.handle_get_random_bootnode_requests):
            self.manager.run_daemon_task(handler)

        await self.manager.wait_finished()
Ejemplo n.º 21
0
class HeaderCheckpointSyncer(HeaderSyncerAPI):
    """
    Wrap a "real" header syncer and hold back all headers until a checkpoint is set.

    :meth:`new_sync_headers` yields nothing until :meth:`set_checkpoint_headers` has
    been called.  It then yields the checkpoint headers first, and only after that starts
    iterating the wrapped syncer, passing every batch it produces straight through.

    A body syncer can use this to pause syncing until a header checkpoint is reached.
    """
    logger = get_extended_debug_logger(
        'trinity.sync.beam.chain.HeaderCheckpointSyncer')

    def __init__(self, passthrough: HeaderSyncerAPI) -> None:
        self._real_syncer = passthrough
        self._at_checkpoint = asyncio.Event()
        self._checkpoint_headers: Tuple[BlockHeader, ...] = None

    def set_checkpoint_headers(self, headers: Tuple[BlockHeader, ...]) -> None:
        """
        Mark ``headers`` as the checkpoint; they are the first batch yielded.

        Right after they are yielded, headers from the wrapped syncer are consumed and
        passed through.
        """
        self._checkpoint_headers = headers
        self._at_checkpoint.set()

    async def new_sync_headers(
            self,
            max_batch_size: int = None
    ) -> AsyncIterator[Tuple[BlockHeader, ...]]:
        # Wait until set_checkpoint_headers() has been called.
        await self._at_checkpoint.wait()

        checkpoint = self._checkpoint_headers
        self.logger.info("Choosing %s as checkpoint headers to sync from", checkpoint)
        yield checkpoint

        passthrough = self._real_syncer.new_sync_headers(max_batch_size)
        async for batch in passthrough:
            yield batch

    def get_target_header_hash(self) -> Hash32:
        return self._real_syncer.get_target_header_hash()
Ejemplo n.º 22
0
class BaseIsolatedRequestServer(Service):
    """
    Serve inbound peer requests received as event-bus events.

    One daemon task is started per subscribed event type.  Each incoming event is
    handed to :meth:`_handle_msg` (implemented by subclasses) in its own task, which
    presumably queries the local database and replies.
    """
    logger = get_extended_debug_logger(
        'trinity.protocol.common.servers.IsolatedRequestServer')

    def __init__(
            self, event_bus: EndpointAPI, broadcast_config: BroadcastConfig,
            subscribed_events: Iterable[Type[PeerPoolMessageEvent]]) -> None:
        self.event_bus = event_bus
        self.broadcast_config = broadcast_config
        self._subscribed_events = subscribed_events

    async def run(self) -> None:
        for subscribed in self._subscribed_events:
            self.manager.run_daemon_task(self.handle_stream, subscribed)

        await self.manager.wait_finished()

    async def handle_stream(self,
                            event_type: Type[PeerPoolMessageEvent]) -> None:
        """Spawn one handler task per ``event_type`` event while the service runs."""
        while self.manager.is_running:
            async for incoming in self.event_bus.stream(event_type):
                self.manager.run_task(
                    self._quiet_handle_msg, incoming.session, incoming.command)

    async def _quiet_handle_msg(self, session: SessionAPI,
                                cmd: CommandAPI[Any]) -> None:
        """Run :meth:`_handle_msg`, logging (not raising) unexpected errors."""
        try:
            await self._handle_msg(session, cmd)
        except asyncio.CancelledError:
            # Re-raised explicitly so cancellation is not reported by the generic
            # handler below as an unexpected error.
            raise
        except Exception:
            self.logger.exception(
                "Unexpected error when processing msg from %s", session)

    @abstractmethod
    async def _handle_msg(self, session: SessionAPI,
                          cmd: CommandAPI[Any]) -> None:
        ...
Ejemplo n.º 23
0
class HeaderCache:
    """
    The ``HeaderCache`` is responsible for holding on to all headers during validation until
    they are persisted in the database. This is necessary because validation in Clique depends
    on looking up past headers which may not be persisted at the time when they are needed.
    """

    logger = get_extended_debug_logger(
        'eth.consensus.clique.header_cache.HeaderCache')

    def __init__(self, chaindb: ChainDatabaseAPI, gc_threshold: int = 1000) -> None:
        """
        :param chaindb: database whose canonical head decides what :meth:`evict` keeps
        :param gc_threshold: number of blocks below the canonical head's block number that
            :meth:`evict` still keeps (default 1000)
        """
        self._chaindb = chaindb
        self._cache: Dict[Hash32, BlockHeader] = {}
        self._gc_threshold = gc_threshold

    def __getitem__(self, key: Hash32) -> BlockHeader:
        return self._cache[key]

    def __setitem__(self, key: Hash32, value: BlockHeader) -> None:
        self._cache[key] = value

    def __contains__(self, key: bytes) -> bool:
        return key in self._cache

    def __delitem__(self, key: Hash32) -> None:
        del self._cache[key]

    def __len__(self) -> int:
        return len(self._cache)

    def evict(self) -> None:
        """
        Evict all headers from the cache that have a block number lower than the oldest
        block number that is considered needed (canonical head's number minus the GC
        threshold).
        """
        head = self._chaindb.get_canonical_head()
        oldest_needed_block_number = head.block_number - self._gc_threshold

        # Delete by the key each entry is actually stored under: ``__setitem__`` accepts any
        # key, so popping by ``header.hash`` could miss (KeyError) or hit the wrong entry.
        # Keys are collected first because the dict must not change size during iteration.
        stale_keys = [
            key for key, header in self._cache.items()
            if header.block_number < oldest_needed_block_number
        ]
        for key in stale_keys:
            del self._cache[key]

        self.logger.debug2("Finished cache cleanup. Cache length: %s",
                           len(self))
Ejemplo n.º 24
0
class MetricsService(Service):
    """
    A service to provide a registry where metrics instruments can be registered and retrieved from.
    It continuously reports metrics to the specified InfluxDB instance.
    """
    logger = get_extended_debug_logger(
        'trinity.components.builtin.metrics.MetricsService')

    def __init__(self,
                 influx_server: str,
                 influx_user: str,
                 influx_password: str,
                 influx_database: str,
                 host: str,
                 reporting_frequency: int = 10,
                 influx_protocol: str = 'https',
                 influx_port: int = 443) -> None:
        """
        :param influx_server: InfluxDB server to report to
        :param influx_user: InfluxDB user name
        :param influx_password: InfluxDB password
        :param influx_database: InfluxDB database receiving the metrics
        :param host: passed to ``HostMetricsRegistry``
        :param reporting_frequency: interval handed to ``trio_utils.every`` between
            reports (presumably seconds)
        :param influx_protocol: protocol used to reach the server; defaults to the
            previously hard-coded ``'https'``
        :param influx_port: port used to reach the server; defaults to the previously
            hard-coded ``443``
        """
        self._influx_server = influx_server
        self._reporting_frequency = reporting_frequency
        self._registry = HostMetricsRegistry(host)
        self._reporter = InfluxReporter(registry=self._registry,
                                        protocol=influx_protocol,
                                        port=influx_port,
                                        database=influx_database,
                                        username=influx_user,
                                        password=influx_password,
                                        server=influx_server)

    @property
    def registry(self) -> HostMetricsRegistry:
        """
        Return the :class:`trinity.components.builtin.metrics.registry.HostMetricsRegistry` at which
        metrics instruments can be registered and retrieved.
        """
        return self._registry

    async def run(self) -> None:
        self.logger.info("Reporting metrics to %s", self._influx_server)
        self.manager.run_daemon_task(self._continuously_report)
        await self.manager.wait_finished()

    async def _continuously_report(self) -> None:
        """Call ``report_now()`` on the reporter at every tick of ``_reporting_frequency``."""
        async for _ in trio_utils.every(self._reporting_frequency):
            self._reporter.report_now()
Ejemplo n.º 25
0
class GasMeter(GasMeterAPI):
    """
    Track the gas of one execution: what was started with, what remains, and what has
    been refunded (accumulated through ``refund_strategy``).
    """

    start_gas: int = None
    gas_refunded: int = None
    gas_remaining: int = None

    logger = get_extended_debug_logger('eth.gas.GasMeter')

    def __init__(
            self,
            start_gas: int,
            refund_strategy: RefundStrategy = default_refund_strategy) -> None:
        """:param start_gas: initial gas; validated with ``validate_uint256``."""
        validate_uint256(start_gas, title="Start Gas")

        self.refund_strategy = refund_strategy
        self.start_gas = start_gas
        self.gas_remaining = start_gas
        self.gas_refunded = 0

    #
    # Write API
    #
    def consume_gas(self, amount: int, reason: str) -> None:
        """
        Subtract ``amount`` from the remaining gas.

        :raise ValidationError: if ``amount`` is negative
        :raise OutOfGas: if ``amount`` exceeds the remaining gas (nothing is subtracted)
        """
        if amount < 0:
            raise ValidationError("Gas consumption amount must be positive")

        remaining = self.gas_remaining
        if amount > remaining:
            raise OutOfGas(f"Out of gas: Needed {amount} "
                           f"- Remaining {remaining} "
                           f"- Reason: {reason}")

        self.gas_remaining = remaining - amount

    def return_gas(self, amount: int) -> None:
        """
        Add ``amount`` back to the remaining gas.

        :raise ValidationError: if ``amount`` is negative
        """
        if amount < 0:
            raise ValidationError("Gas return amount must be positive")

        self.gas_remaining = self.gas_remaining + amount

    def refund_gas(self, amount: int) -> None:
        """Combine ``amount`` into ``gas_refunded`` using the configured refund strategy."""
        previous_refund = self.gas_refunded
        self.gas_refunded = self.refund_strategy(previous_refund, amount)
Ejemplo n.º 26
0
class ETHProtocol(BaseProtocol):
    """The ``eth`` protocol, version 63."""
    name = 'eth'
    version = 63
    command_length = 17

    logger = get_extended_debug_logger('trinity.protocol.eth.proto.ETHProtocol')

    # Order matters here (presumably it maps to command ids) -- do not reorder.
    commands = (
        Status,
        NewBlockHashes,
        Transactions,
        GetBlockHeaders,
        BlockHeaders,
        GetBlockBodies,
        BlockBodies,
        NewBlock,
        GetNodeData,
        NodeData,
        GetReceipts,
        Receipts,
    )
Ejemplo n.º 27
0
class ETHProtocolV64(BaseETHProtocol):
    """The ``eth`` protocol, version 64."""
    version = 64
    command_length = 17
    status_command_type = Status

    logger = get_extended_debug_logger('trinity.protocol.eth.proto.ETHProtocolV64')

    # Order matters here (presumably it maps to command ids) -- do not reorder.
    # NOTE(review): the request/response commands are the ``*V65`` classes; confirm
    # that they are the intended message definitions for version 64.
    commands = (
        Status,
        NewBlockHashes,
        Transactions,
        GetBlockHeadersV65,
        BlockHeadersV65,
        GetBlockBodiesV65,
        BlockBodiesV65,
        NewBlock,
        GetNodeDataV65,
        NodeDataV65,
        GetReceiptsV65,
        ReceiptsV65,
    )
Ejemplo n.º 28
0
class ProxyLESAPI:
    """
    ``LESAPI`` stand-in for use outside the process that runs the peer pool.

    Every action is delegated, over the event bus, to the process running the peer pool.
    """
    logger = get_extended_debug_logger(
        'trinity.protocol.les.proxy.ProxyLESAPI')

    def __init__(self, session: SessionAPI, event_bus: EndpointAPI,
                 broadcast_config: BroadcastConfig):
        self.session = session
        self._event_bus = event_bus
        self._broadcast_config = broadcast_config

    def raise_if_needed(self, value: SupportsError) -> None:
        """Log and re-raise ``value.error`` when it is set; otherwise do nothing."""
        error = value.error
        if error is None:
            return
        self.logger.warning(
            "Raised %s while fetching from peer %s",
            error,
            self.session,
        )
        raise error

    def send_block_headers(self,
                           headers: Sequence[BlockHeaderAPI],
                           buffer_value: int = 0,
                           request_id: int = None) -> int:
        """
        Have the peer-pool process send ``headers`` to this session's peer.

        A fresh request id is generated when ``request_id`` is None.  The event is
        broadcast without waiting.  Returns the request id carried by the command.
        """
        payload = BlockHeadersPayload(
            request_id=gen_request_id() if request_id is None else request_id,
            buffer_value=buffer_value,
            headers=tuple(headers),
        )
        command = BlockHeaders(payload)
        self._event_bus.broadcast_nowait(
            SendBlockHeadersEvent(self.session, command),
            self._broadcast_config,
        )
        return command.payload.request_id
Ejemplo n.º 29
0
    def __init__(
        self,
        local_private_key: bytes,
        local_node_id: NodeID,
        remote_endpoint: Endpoint,
        enr_db: ENRDatabaseAPI,
        inbound_message_send_channel: trio.abc.SendChannel[AnyInboundMessage],
        outbound_envelope_send_channel: trio.abc.SendChannel[OutboundEnvelope],
        message_type_registry: MessageTypeRegistry = v51_registry,
        events: Optional[EventsAPI] = None,
    ) -> None:
        """
        Initialise a session with the peer at ``remote_endpoint``.

        The session starts in ``SessionStatus.BEFORE`` with a fresh random UUID as its
        ``id`` and ``created_at`` set to ``trio.current_time()``.  A fresh ``Events()``
        is used when ``events`` is omitted.
        """
        self.logger = get_extended_debug_logger("ddht.Session")

        self.id = uuid.uuid4()

        self.created_at = trio.current_time()
        self._handshake_complete = trio.Event()

        self._events = Events() if events is None else events
        # Counter starting at 0 (presumably for nonces, judging by the name).
        self._nonce_counter = itertools.count()

        self._local_private_key = local_private_key
        self._local_node_id = local_node_id
        self.remote_endpoint = remote_endpoint
        self._enr_db = enr_db

        self._message_type_registry = message_type_registry

        self._status = SessionStatus.BEFORE

        # Buffer of up to 256 outbound messages; both ends are kept on the session.
        buffer_send, buffer_receive = trio.open_memory_channel[AnyOutboundMessage](256)
        self._outbound_message_buffer_send_channel = buffer_send
        self._outbound_message_buffer_receive_channel = buffer_receive

        self._inbound_message_send_channel = inbound_message_send_channel
        self._outbound_envelope_send_channel = outbound_envelope_send_channel
Ejemplo n.º 30
0
async def DatagramReceiver(
    manager: ManagerAPI,
    sock: SocketType,
    inbound_datagram_send_channel: SendChannel[InboundDatagram],
) -> None:
    """
    Read datagrams from ``sock`` and forward each one, with its sender's endpoint, to
    ``inbound_datagram_send_channel``.

    Loops while ``manager`` is running.  If sending raises ``trio.BrokenResourceError``
    (the receiving side is gone), the manager is cancelled and the function returns.
    """
    logger = get_extended_debug_logger("ddht.DatagramReceiver")

    async with inbound_datagram_send_channel:
        while manager.is_running:
            datagram, sender = await sock.recvfrom(DISCOVERY_DATAGRAM_BUFFER_SIZE)
            ip_address, port = sender
            endpoint = Endpoint(inet_aton(ip_address), port)
            logger.debug2("Received %d bytes from %s", len(datagram), endpoint)

            inbound = InboundDatagram(datagram, endpoint)
            try:
                await inbound_datagram_send_channel.send(inbound)
            except trio.BrokenResourceError:
                logger.debug(
                    "DatagramReceiver exiting due to `trio.BrokenResourceError`"
                )
                manager.cancel()
                return