async def remove_client(self, writer: asyncio.StreamWriter):
    """Close the client's input and output streams.

    The write side is half-closed first when the transport supports it, then
    the transport is closed and awaited until it has finished shutting down.
    A "Disconnected client" message is logged once that completes.

    :param writer: Stream writer of the client being disconnected.
    :raises Exception: Anything raised by ``write_eof()`` or
        ``wait_closed()`` is propagated unchanged; in the former case
        ``close()`` has still been called before the error surfaces.
    """
    try:
        if writer.can_write_eof():
            writer.write_eof()
    finally:
        # Closing must not depend on the half-close succeeding: without this
        # a failing write_eof() would leave the transport open forever.
        writer.close()
    await writer.wait_closed()
    self._logger.info("Disconnected client")
def __check_writer(writer: asyncio.StreamWriter) -> bool:
    """Tell whether ``writer`` still looks usable for sending data.

    The checks are tried in order and the first applicable one decides:

    1. ``None`` is never usable.
    2. If the writer exposes ``is_closing()``, it is usable unless closing.
    3. Otherwise, if it has a truthy ``transport``, that transport's
       ``is_closing()`` decides.
    4. As a last resort the result of ``can_write_eof()`` is returned.
    """
    if writer is None:
        usable = False
    elif hasattr(writer, "is_closing"):
        usable = not writer.is_closing()
    elif writer.transport:
        usable = not writer.transport.is_closing()
    else:
        usable = writer.can_write_eof()
    return usable
async def _connection_wrapper(
    application: core.ApplicationType,
    client_connections: "Set[asyncio.Task[None]]",
    container: core.Container,
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
) -> None:
    """
    Serve a single client connection with an ASGI application.

    Intended as the connection callback for ``start_server`` or
    ``start_unix_server``; bind ``application``, ``client_connections`` and
    ``container`` beforehand (for example with ``functools.partial``).

    :param application: The ASGI application.
    :param client_connections: A set of Task objects; the current task is
        added on entry and removed again on exit.
    :param container: The ASGI container used to run the application.
    :param reader: The stream reader for the connection.
    :param writer: The stream writer for the connection.
    """
    # Register the running task so the owner can track open connections.
    this_task = asyncio.current_task()
    assert this_task is not None, "_connection_wrapper must be called inside a task"
    client_connections.add(this_task)
    try:
        try:
            # The container wants a "read one chunk" callable and a "write
            # one chunk, optionally drain" callable; adapt the streams.
            async def send_chunk(data: bytes, drain: bool) -> None:
                writer.write(data)
                if drain:
                    await writer.drain()

            read_chunk = functools.partial(reader.read, io.DEFAULT_BUFFER_SIZE)

            # Hand control to the application.
            await container.run(application, read_chunk, send_chunk)
        except Exception:  # pylint: disable=broad-except
            logging.getLogger(__name__).exception("Uncaught exception in application callable")
        finally:
            # Tear the connection down; failures here carry no useful
            # information, so they are deliberately ignored.
            try:
                if writer.can_write_eof():
                    writer.write_eof()
                writer.close()
            except Exception:  # pylint: disable=broad-except
                pass
    finally:
        # Unregister this task from the set of open client connections.
        client_connections.remove(this_task)
async def communicate(
    self,
    reader: StreamReader,
    writer: StreamWriter,
):
    """Serve one peer connection until its payload stream ends.

    A background task running ``self._request_writer(protocol)`` is started,
    then payloads are consumed from the protocol: each ``payload.request`` is
    executed in its own task via ``self.execute`` and each
    ``payload.response`` is passed to ``self._on_response``. Any exception
    raised while iterating is logged, not propagated. On exit the writer is
    half-closed (when supported) and closed, the request-writer task is
    cancelled, and the writer is awaited until fully closed.

    :param reader: Stream reader of the peer connection.
    :param writer: Stream writer of the peer connection.
    """
    protocol = Protocol(reader, writer)
    request_writer_task = self.create_task(self._request_writer(protocol))

    # ``get_extra_info`` returns None when the peer name is unavailable.
    # Slicing that raised TypeError *after* the request-writer task had been
    # started and *before* the try/finally below, leaking both the task and
    # the writer. Fall back to a placeholder so cleanup always runs.
    peername = writer.get_extra_info("peername")
    if isinstance(peername, (tuple, list)) and len(peername) >= 2:
        client_host, client_port = peername[:2]
    else:
        client_host, client_port = "unknown", 0

    # Bracket IPv6 literals so the logged URL stays unambiguous.
    if ":" in client_host:
        client_host = f"[{client_host}]"

    log.info(
        "Start communication with tcp://%s:%d", client_host,
        client_port,
    )

    # noinspection PyBroadException
    try:
        async for payload in protocol:
            if payload.request is not None:
                self.create_task(self.execute(protocol, payload.request))
            if payload.response is not None:
                self._on_response(payload.response)
    except Exception:
        log.exception(
            "Error when communication tcp://%s:%d/",
            client_host, client_port,
        )
    finally:
        if writer.can_write_eof():
            writer.write_eof()
        writer.close()
        log.info(
            "Communication with tcp://%s:%d finished", client_host,
            client_port,
        )
        await cancel_tasks([request_writer_task])
        await writer.wait_closed()
async def ahandle_peer(self, reader: StreamReader, writer: StreamWriter) -> None:
    """Consume length-prefixed DNS queries from a peer and resolve them concurrently.

    Each query is framed by a 2-byte big-endian length. Every complete query
    is handed to ``self.ahandle_query`` in its own task, together with the
    writer and a lock shared by all tasks of this connection. Reading stops
    at the first incomplete frame; the stream is closed after all scheduled
    tasks have finished.
    """
    pending: Union[List[Task], Set[Task]] = []
    write_lock = aio.Lock()

    peer = writer.transport.get_extra_info("peername")
    logging.debug(f'Got TCP DNS query stream from {peer}')

    # Phase 1: pull framed queries off the wire until the peer stops writing.
    while True:
        try:
            length_prefix = await reader.readexactly(2)
            payload = await reader.readexactly(int.from_bytes(length_prefix, 'big'))
        except aio.IncompleteReadError:
            # The peer finished (or truncated) its stream.
            break

        worker = aio.create_task(self.ahandle_query(writer, write_lock, payload))
        pending.append(worker)

    # Phase 2: drain the scheduled work, reporting failures as they surface.
    while pending:
        finished, pending = await aio.wait(pending, return_when=aio.FIRST_COMPLETED)
        for worker in finished:
            failure = worker.exception()
            if failure is not None:
                logging.warning(f'TCP DNS query resolution encountered an error - {failure!r}')

    # Phase 3: shut the stream down unless somebody already started to.
    if writer.is_closing():
        return

    # Signal end-of-output first when the transport allows a half-close.
    if writer.can_write_eof():
        writer.write_eof()

    writer.close()
    await writer.wait_closed()
async def __handle_client(
    self, reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
) -> None:
    """Handle one worker connection: handshake, then feed it queued tasks.

    The first packet received is unpacked as
    ``(packet_type, worker_id, digest, pid)``. After the handshake checks
    pass, ``AUTH_OK`` is sent and the coroutine loops, taking work items from
    ``self.tasks`` and exchanging request/result packets with the worker
    until the reader hits EOF or the connection breaks.

    NOTE(review): block nesting below was reconstructed from a source whose
    indentation had been lost; confirm against the upstream file, in
    particular the extent of the ``async with`` block.

    :param reader: Stream reader of the worker connection.
    :param writer: Stream writer of the worker connection.
    :raises Exception: Any error other than ``asyncio.IncompleteReadError``
        or ``ConnectionError`` raised while serving a task is re-raised
        after being stored on the task's result future.
    """
    proto = AsyncProtocol(reader, writer)
    packet_type, worker_id, digest, pid = await proto.receive()

    async with self.__closing_lock:
        if self.__closing:
            # NOTE(review): no return follows, so the handshake below still
            # runs against the closed protocol — confirm this is intended.
            proto.close()

    # The worker reports that its initializer failed: the next packet is
    # expected to carry the exception.
    if packet_type == PacketTypes.BAD_INITIALIZER:
        packet_type, exc = await proto.receive()
        if packet_type != PacketTypes.EXCEPTION:
            await proto.send(PacketTypes.BAD_PACKET)
        else:
            set_exception(self.__futures, exc)
        await self.close()
        return

    # Anything other than an AUTH packet is a protocol violation.
    if packet_type != PacketTypes.AUTH:
        await proto.send(PacketTypes.BAD_PACKET)
        if writer.can_write_eof():
            writer.write_eof()
        return

    if worker_id not in self.worker_ids:
        log.error("Unknown worker with id %r", worker_id)
        return

    # The worker proves it knows the cookie by sending HMAC-SHA256 of its
    # worker id keyed with that cookie.
    expected_digest = hmac.HMAC(
        self.__cookie,
        worker_id,
        digestmod=hashlib.sha256,
    ).digest()

    # NOTE(review): ``!=`` is not a constant-time comparison; consider
    # ``hmac.compare_digest`` if timing leaks matter for this channel.
    if expected_digest != digest:
        await proto.send(PacketTypes.AUTH_FAIL)
        if writer.can_write_eof():
            writer.write_eof()
        log.debug("Bad digest %r expected %r", digest, expected_digest)
        return

    await proto.send(PacketTypes.AUTH_OK)

    # Book-keeping for the authenticated worker; ``processes`` and ``pids``
    # are undone in the ``finally`` below, ``spawning`` is not touched again
    # in this method.
    self._statistic.processes += 1
    self._statistic.spawning += 1
    self.pids.add(pid)

    try:
        while not reader.at_eof():
            func: Callable
            args: Tuple[Any, ...]
            kwargs: Dict[str, Any]
            result_future: asyncio.Future
            process_future: asyncio.Future

            (
                func, args, kwargs, result_future, process_future,
            ) = await self.tasks.get()

            # Skip work items whose futures were already resolved.
            if process_future.done() or result_future.done():
                continue

            try:
                # Publish which process (by pid) took the task.
                process_future.set_result(pid)

                await proto.send((PacketTypes.REQUEST, func, args, kwargs))
                packet_type, payload = await proto.receive()

                # The future may have been resolved while the worker ran.
                if result_future.done():
                    log.debug(
                        "Result future %r already done, skipping",
                        result_future,
                    )
                    continue

                if packet_type == PacketTypes.RESULT:
                    result_future.set_result(payload)
                elif packet_type in (
                    PacketTypes.EXCEPTION, PacketTypes.CANCELLED,
                ):
                    result_future.set_exception(payload)
                # Drop the references before blocking on the next task.
                del packet_type, payload
            except (asyncio.IncompleteReadError, ConnectionError):
                # The connection to the worker broke mid-task: fail the
                # pending future and stop serving this worker.
                if not result_future.done():
                    result_future.set_exception(
                        ProcessError(f"Process {pid!r} unexpected exited"),
                    )
                break
            except Exception as e:
                # Unexpected failure: report it to the caller, close the
                # connection, and let the error propagate.
                if not result_future.done():
                    result_future.set_exception(e)

                if not writer.is_closing():
                    if writer.can_write_eof():
                        writer.write_eof()
                    writer.close()
                raise
    finally:
        self._statistic.processes -= 1
        self.pids.remove(pid)
async def client_handler(rx: asyncio.StreamReader, tx: asyncio.StreamWriter):
    """Serve one client over a length-prefixed JSON protocol.

    Incoming frames consist of a ``FRAME_HEADER_LEN``-byte big-endian length
    followed by a JSON document whose first two elements are a uuid and a
    request. Outgoing frames are JSON documents prefixed by a 4-byte
    big-endian length. Supported request keys are ``Reboot``, ``Halt``,
    ``Upload`` and ``Process``. The handler multiplexes between reading from
    the client and forwarding responses placed on an internal queue, and
    closes the stream once the reader reaches EOF.
    """
    stdin_queues = {}
    outbound = asyncio.Queue()
    inbox = bytearray()
    loop = asyncio.get_event_loop()

    async def send_frame(payload):
        # Serialise to JSON, prefix with the 4-byte length, and flush.
        body = json.dumps(payload).encode('utf-8')
        tx.write(len(body).to_bytes(4, 'big') + body)
        await tx.drain()

    reader_task = loop.create_task(rx.read(1024))
    outbound_task = loop.create_task(outbound.get())
    waiting = {reader_task, outbound_task}

    while waiting and not rx.at_eof():
        finished, waiting = await asyncio.wait(
            waiting, return_when=asyncio.FIRST_COMPLETED)

        if outbound_task in finished:
            # A queued response is ready: ship it and re-arm the queue wait.
            await send_frame(await outbound_task)
            outbound_task = loop.create_task(outbound.get())
            waiting.add(outbound_task)

        if reader_task not in finished:
            continue

        # New bytes arrived: buffer them and re-arm the read.
        inbox.extend(await reader_task)
        reader_task = loop.create_task(rx.read(1024))
        waiting.add(reader_task)

        # Peel off every complete frame currently in the buffer.
        while len(inbox) >= FRAME_HEADER_LEN:
            body_length = int.from_bytes(
                inbox[:FRAME_HEADER_LEN], byteorder='big', signed=False)
            frame_end = FRAME_HEADER_LEN + body_length
            if len(inbox) < frame_end:
                # Frame is still incomplete; wait for more data.
                break
            raw_message = inbox[FRAME_HEADER_LEN:frame_end]
            inbox = inbox[frame_end:]

            try:
                decoded = json.loads(raw_message)
            except JSONDecodeError:
                logger.warning('Could not decode message')
                continue

            uuid, request = decoded[0], decoded[1]

            # Power commands: acknowledge first, then run the system command.
            for request_key, command in (('Reboot', 'reboot'), ('Halt', 'halt')):
                if request_key in request:
                    await send_frame([uuid, 'Ok'])
                    power_process = await asyncio.create_subprocess_exec(command)
                    await power_process.wait()

            if 'Upload' in request:
                upload_spec = request['Upload']
                contents = bytes(upload_spec['contents'])
                filename = upload_spec['filename']
                path = upload_spec['path']
                filepath = path + filename if path.endswith('/') else path + '/' + filename
                logger.info('uploading: %s', filepath)
                try:
                    with open(filepath, 'wb') as upload:
                        upload.write(contents)
                except Exception as error:
                    outcome = {'Error': str(error)}
                else:
                    outcome = 'Ok'
                await outbound.put([uuid, outcome])

            if 'Process' in request:
                process_request = request['Process']
                if 'Run' in process_request:
                    # Give the new process its own inbound queue; its output
                    # goes to the shared client queue.
                    process_inbox = asyncio.Queue()
                    stdin_queues[uuid] = process_inbox
                    loop.create_task(
                        process(uuid, process_request['Run'], outbound, process_inbox))
                wants_forward = (
                    'StandardInput' in process_request
                    or 'Terminate' in process_request
                )
                if wants_forward and uuid in stdin_queues:
                    # Relay the request to the matching process handler.
                    await stdin_queues[uuid].put(process_request)

    # The client is gone: stop the outstanding waits and close the stream.
    for leftover in waiting:
        leftover.cancel()
    if tx.can_write_eof():
        tx.write_eof()
    await tx.drain()
    tx.close()
    await tx.wait_closed()
def fdms_session(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Run one FDMS session over a stream connection.

    Written as a generator that delegates with ``yield from``; any coroutine
    decorator lies outside this block. The session opens by sending ENQ,
    then repeats: collect request packets, process the collected
    transactions, send the response and wait for the peer's ACK/NAK. It ends
    by sending EOT and half-closing the writer, or returns early on timeouts,
    empty reads, or too many failed attempts.

    NOTE(review): block nesting below was reconstructed from a source whose
    indentation had been lost; confirm against the upstream file, in
    particular the position of the ``wcc`` check and of the ``drain`` call
    at the end of the request loop.

    :param reader: Stream reader of the connection.
    :param writer: Stream writer of the connection.
    """
    online = None
    ''':type: (FdmsHeader, FdmsTransaction)'''
    add_on = None
    ''':type: (FdmsHeader, FdmsTransaction)'''
    offline = list()

    # Open the session by sending ENQ.
    writer.write(bytes((ENQ,)))
    yield from writer.drain()

    while True:

        # Get Request: read packets until the batch is complete.
        attempt = 0
        while True:
            try:
                # Give up once more than four packets in a row were rejected.
                if attempt > 4:
                    return

                request = yield from asyncio.wait_for(read_fdms_packet(reader), timeout=15.0)
                # An empty packet ends the session.
                if len(request) == 0:
                    return

                control_byte = request[0]
                if control_byte == STX:
                    # Checksum: XOR of bytes 1..-2 must equal the final byte.
                    lrs = functools.reduce(lambda x, y: x ^ int(y), request[2:-1], int(request[1]))
                    if lrs != request[-1]:
                        raise ValueError('LRS sum')

                    pos, header = parse_header(request)
                    txn = header.create_txn()
                    txn.parse(request[pos:-2])
                    # The first online transaction becomes ``online``; any
                    # later online one overwrites ``add_on``. Everything
                    # else is queued as offline.
                    if header.txn_type == FdmsTransactionType.Online.value:
                        if online is None:
                            online = (header, txn)
                        else:
                            add_on = (header, txn)
                    else:
                        offline.append((header, txn))

                    # Protocol type '2' ends the request phase without an ACK.
                    if header.protocol_type == '2':
                        break

                    # Respond with ACK and reset the failure counter.
                    attempt = 0
                    writer.write(bytes((ACK,)))

                elif control_byte == EOT:
                    break

            # Close session when no packet arrives within the timeout.
            except asyncio.TimeoutError:
                return

            # Respond with NAK on any other failure (bad checksum, parse error).
            except Exception as e:
                logging.getLogger(LOG_NAME).debug('Request error: %s', str(e))
                attempt += 1
                writer.write(bytes((NAK,)))

            # Flush whichever of ACK / NAK was written above.
            yield from writer.drain()

        # Nothing to answer without an online transaction.
        if online is None:
            return

        # Process Transactions & Send Response
        # Offline results are overwritten below; only the online result is sent.
        for txn in offline:
            rs = process_txn(txn)
        offline.clear()

        if add_on is not None:
            process_add_on_txn(online, add_on)
        add_on = None

        rs = process_txn(online)

        # Send Response
        rs_bytes = rs.response()

        if rs.action_code == FdmsActionCode.HostSpecificPoll or rs.action_code == FdmsActionCode.RevisionInquiry:
            # These two action codes are sent once, without waiting for ACK.
            writer.write(rs_bytes)
            yield from writer.drain()
        else:
            # Send the response until the peer ACKs it, up to four times.
            attempt = 0
            while True:
                if attempt >= 4:
                    return

                writer.write(rs_bytes)
                yield from writer.drain()

                control_byte = 0
                try:
                    # Skip bytes until an ACK or NAK shows up; the high bit
                    # of each byte is masked off before comparing.
                    while True:
                        rs_head = yield from asyncio.wait_for(reader.read(1), timeout=4.0)
                        if len(rs_head) == 0:
                            return
                        control_byte = rs_head[0] & 0x7f
                        if control_byte == ACK:
                            break
                        elif control_byte == NAK:
                            break

                # Close session when the peer does not answer in time.
                except asyncio.TimeoutError as e:
                    return

                if control_byte == ACK:
                    break
                else:
                    attempt += 1

        # A ``wcc`` of 'B' or 'C' on the online header keeps the session
        # open for another round; anything else ends it.
        if online[0].wcc in {'B', 'C'}:
            # Send ENQ
            writer.write(bytes((ENQ,)))
            yield from writer.drain()
            continue
        else:
            break

    # End the session: send EOT, then half-close if the transport allows it.
    writer.write(bytes((EOT,)))
    yield from writer.drain()
    if writer.can_write_eof():
        writer.write_eof()