def _send_rst(self, to: AddressType):
    """
    Send a reset (RST) datagram to a peer.

    The datagram consists solely of a serialized header whose flags are
    ``RST | REPLY``.

    :param to: address the datagram is sent to.
    """
    rst_flags = PacketFlags.RST | PacketFlags.REPLY
    rst_header = PacketHeader(flags=rst_flags)
    self._transport.sendto(rst_header.serialize(), to)
def __call__(self, packet: bytes) -> Optional[bytes]:
    """
    Decide whether a packet survives the dropper.

    Packets that cannot be decoded, packets carrying any of the flags in
    ``self.CONTROL_FLAGS``, and packets whose client ID is falsy (zero) are
    always let through. Every other packet is let through only when a
    random draw is at least ``self.prob``.

    :param packet: raw packet to examine.
    :return: the packet, unchanged, if it passes; ``None`` if it is to be
        dropped.
    """
    try:
        header = PacketHeader.deserialize(packet)
    except Exception:
        # Undecodable packets are forwarded untouched.
        return packet

    # Control messages are never dropped, and neither are packets that do
    # not carry a client ID.
    is_control = any(flag in header.flags for flag in self.CONTROL_FLAGS)
    if is_control or not header.client_id.value:
        return packet

    # Keep the packet when the draw lands at or above the drop probability.
    return packet if random.random() >= self.prob else None
def reply_packet(
    orig_hdr: PacketHeader,
    flags: PacketFlags,
    status: ExecutionStatus,
    ret: bytes = b"",
) -> bytes:
    """
    Generate a reply packet with an execution status and a method return
    value.

    The packet layout is: a serialized header, one status byte, then the
    serialized return value. The header copies the client ID, transaction
    number, semantics and method ordinal from the original request and
    uses the supplied flags.

    :param orig_hdr: header of original request.
    :param flags: packet flags for the reply header.
    :param status: RPC execution status; its ``value`` is written as a
        single byte.
    :param ret: serialized return value.
    :return: packet data as an immutable ``bytes`` object.
    """
    out = bytearray(
        PacketHeader(
            orig_hdr.client_id,
            orig_hdr.trans_num,
            flags,
            orig_hdr.semantics,
            orig_hdr.method_ordinal,
        ).serialize()
    )
    out.append(status.value)
    out.extend(ret)
    # Freeze the buffer so the result matches the declared ``bytes`` return
    # type instead of leaking a mutable bytearray to callers.
    return bytes(out)
async def _ping_loop(self, interval: float):
    """
    Send PING packets forever, sleeping between each one.

    The loop never returns on its own; each iteration sleeps first and
    then sends a header flagged ``PING`` that carries the current client
    ID.

    :param interval: delay between PING packets (in seconds).
    """
    while True:
        await asyncio.sleep(interval)
        # The header is rebuilt every round so that the client ID is read
        # at send time.
        ping_header = PacketHeader(
            client_id=u32(self._cid),
            flags=PacketFlags.PING,
        )
        self._send(ping_header.serialize())
def _close(self, send_rst: bool = True): """ Close the connection. Safe to call if already closed. :param send_rst: whether to send a RST packet. """ if self.closed: return if send_rst: self._send(PacketHeader(flags=PacketFlags.RST).serialize()) if (t := self._ping_task) is not None: t.cancel()
async def _process_task(self, data: bytes):
    """
    Handle one incoming datagram.

    Datagrams that cannot be decoded, or that are not replies, are
    ignored. Until the connected event is set, only a reply carrying
    ``CHANGE_CID`` is acted on: it supplies the client ID, sets the
    connected event and starts the ping and inactivity-check tasks. Once
    connected, a RST closes the connection without sending a RST back,
    replies for a different client ID are ignored, PING replies only
    refresh the activity timestamp, and anything else is handed to the
    router.

    See ``datagram_received()`` for more information on the parameters.

    :param data: raw datagram contents.
    """
    try:
        header = PacketHeader.deserialize(data)
        body = data[header.LENGTH:]
    except Exception:
        # Undecodable datagram; nothing to do.
        return

    # Only replies are of interest here.
    if not header.is_reply:
        return

    if not self._connected_event.is_set():
        # Still waiting for a client ID to be assigned.
        # todo do we want to ignore RSTs?
        if not (header.flags & PacketFlags.CHANGE_CID):
            return

        self._cid = header.client_id.value
        self._connected_event.set()
        self._last_activity_time = time.monotonic()

        # Start the background loops now that the connection is up.
        self._ping_task = asyncio.create_task(
            self._ping_loop(self._ping_interval)
        )
        self._inactivity_check_task = asyncio.create_task(
            self._inactivity_check_loop(self._inactivity_timeout)
        )
        return

    # RST is honoured before the client ID comparison below.
    if header.flags & PacketFlags.RST:
        self._close(False)
        return

    # Replies addressed to a different client ID are ignored.
    if header.client_id.value != self._cid:
        return

    self._last_activity_time = time.monotonic()

    # A PING reply only counts as activity; it is not routed.
    if header.flags & PacketFlags.PING:
        return

    self._router.route(header, body)
def process(self, hdr: PacketHeader, payload: bytes):
    """
    Handle a packet received from the client.

    Every packet refreshes the activity timestamp. A PING is answered by
    sending the same header back with the ``REPLY`` flag added. Any other
    packet is passed to the object server in a new task, unless its
    transaction number is already present in the task manager.

    :param hdr: packet header (its flags are modified in place for PING).
    :param payload: packet payload.
    """
    self._last_activity_time = time.monotonic()
    trans_id = hdr.trans_num.value

    if hdr.flags & PacketFlags.PING:
        # Echo the header back, marked as a reply.
        hdr.flags |= PacketFlags.REPLY
        self._send_packet(hdr.serialize())
        return

    # Requests whose txid corresponds to an executing task are ignored;
    # everything else gets a task that runs the oserver.
    if trans_id not in self._tmgr:
        self._tmgr.create_task(trans_id, self._oserver.process(hdr, payload))
async def call(
    self,
    ordinal: int,
    args: bytes,
) -> bytes:
    """
    Call a remote method.

    Will wait for a connection to be established if not already connected.
    The request is sent up to ``retries + 1`` times; every resend after
    the first is flagged ``REPLAYED``. The first reply received ends the
    call.

    :param ordinal: ordinal of the remote method.
    :param args: serialized arguments for the remote method.
    :return: serialized return value for the remote method.
    :raises exceptions.RPCConnectionClosedError: if the connection has been
        closed.
    :raises exceptions.InvalidReplyError: if the reply has an empty payload
        or an unrecognised execution status byte.
    :raises asyncio.TimeoutError: if no reply arrives within any attempt.
    """
    # timeout & retries will be checked by set_semantics().
    if not self:
        await self.wait_connected()

    # Cache all configuration parameters so they remain consistent for
    # this call.
    retries = self.retries
    timeout = self.timeout
    if timeout is not None and retries > 0:
        # NOTE(review): the per-attempt timeout is timeout / retries while
        # retries + 1 attempts are made, so the worst-case total wait is
        # larger than the configured timeout -- confirm this is intended.
        timeout /= retries
    tries = retries + 1
    semantics = self.semantics

    # Generate initial header.
    tid = self._txid.next().copy()
    hdr = PacketHeader(
        u32(self._cid),
        trans_num=tid,
        semantics=semantics,
        method_ordinal=u32(ordinal),
    )

    for attempt in range(tries):
        if attempt:
            # Mark every resend so it can be told apart from the original.
            hdr.flags |= PacketFlags.REPLAYED
        self._send(hdr.serialize() + args)

        # Shouldn't race because it's single threaded, and we don't hit an
        # await till the future gets submitted.
        try:
            _, payload = await asyncio.wait_for(
                self._router.listen(tid.value), timeout)
        except asyncio.TimeoutError:
            continue

        if not payload:
            raise exceptions.InvalidReplyError("zero-length payload")

        # The first payload byte is the execution status.
        try:
            status = ExecutionStatus(payload[0])
        except ValueError as err:
            raise exceptions.InvalidReplyError("execution status") from err

        if semantics is InvocationSemantics.AT_MOST_ONCE:
            # Send acknowledgement to optimize result storage. It doesn't
            # matter if it gets lost because the server will age it out
            # anyway.
            hdr.flags = PacketFlags.ACK_REPLY
            self._send(hdr.serialize())

        excc = estatus_to_exception(status)
        if excc is not None:
            raise excc()
        return payload[1:]

    # Every attempt timed out.
    raise asyncio.TimeoutError
def connection_made(self, transport: transports.DatagramTransport):
    """
    Store the transport and send the initial packet.

    :param transport: datagram transport to keep for this connection.
    """
    self._transport = transport

    # Obtain an ID by sending a default-constructed header.
    initial_header = PacketHeader()
    self._send(initial_header.serialize())
def datagram_received(self, data: bytes, addr: AddressType):
    """
    Handle a datagram arriving from the network.

    Undecodable datagrams and replies are ignored. A RST disconnects the
    client registered for the source address, if any. A packet with a
    falsy (zero) client ID opens a new connection: a random non-zero
    client ID is generated, the client is registered, and the ID is sent
    back in a ``REPLY | CHANGE_CID`` header. Any other packet must come
    from a registered address with a matching client ID to be processed.

    :param data: data received.
    :param addr: datagram source address.
    """
    try:
        header = PacketHeader.deserialize(data)
        body = data[header.LENGTH:]
    except ValueError:
        # Nothing we can do, packet cannot be decoded.
        return

    # Replies are filtered out.
    if header.is_reply:
        return

    # A RST disconnects the active client for this address, if any.
    if header.flags & PacketFlags.RST:
        if addr in self._clients:
            self.disconnect_client(addr, False)
        return

    # Not a RST: continuation of a previous session or a new connection.
    claimed_cid = header.client_id.value

    if not claimed_cid:
        # New connection.
        if addr in self._clients:
            # Delete stale connection for the same network address.
            self.disconnect_client(addr, False)

        skeleton = self._skel_fac(addr)

        # Generate random CID to avoid collisions with previously
        # connected clients.
        claimed_cid = random.randint(u32.min() + 1, u32.max())
        new_client = ConnectedClient(
            caddr=addr,
            skel=skeleton,
            transport=self._transport,
            timeout_callback=self.disconnect_client,
            inactivity_timeout=self._inactivity_timeout,
            result_cache_timeout=self._result_cache_timeout,
        )
        self._clients[addr] = claimed_cid, new_client

        # Register and send new CID.
        cid_flags = PacketFlags.REPLY | PacketFlags.CHANGE_CID
        greeting = PacketHeader(client_id=u32(claimed_cid), flags=cid_flags)
        self._transport.sendto(greeting.serialize(), addr)
        return

    # Existing connection.
    if addr not in self._clients:
        # Unknown client: it thinks it is using an open connection but the
        # server has no record of that connection. Reset to signal that.
        self._send_rst(addr)
        return

    known_cid, client = self._clients[addr]
    if known_cid == claimed_cid:
        client.process(header, body)
    else:
        # The address is known but the client ID is not the registered
        # one, so the server has no record of that connection either.
        # Reset.
        # todo: need to notify servers of shutdown!
        self.disconnect_client(addr)