async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Collect pickled metric batches sent over one stream connection.

    Each frame on the wire is a 4-byte big-endian unsigned length (``'!L'``)
    followed by that many bytes of pickle payload.  Every item of the unpickled
    iterable is appended to the free variable ``data`` (presumably a list owned
    by an enclosing scope — confirm), and ``event`` is set once ``data`` holds
    at least ``count`` items.

    Reading stops at EOF or on a truncated frame.  The writer is always closed
    and EOF is fed to the reader on exit, including when unpickling fails.

    ``data`` is only mutated, never rebound, so no ``nonlocal`` declaration is
    needed; it resolves to the enclosing (or module) binding either way.
    """
    header_size = struct.calcsize('!L')
    try:
        while not reader.at_eof():
            try:
                header = await reader.readexactly(header_size)
            except asyncio.IncompleteReadError:
                break
            (payload_size,) = struct.unpack('!L', header)
            try:
                payload = await reader.readexactly(payload_size)
            except asyncio.IncompleteReadError:
                break
            # SECURITY(review): pickle.loads can execute arbitrary code; this
            # must only ever receive data from trusted peers.
            data.extend(pickle.loads(payload))
            # ">=" so the event still fires when one batch overshoots count.
            if len(data) >= count:
                event.set()
    finally:
        writer.close()
        reader.feed_eof()
async def _connect_streams(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, queue: "asyncio.Queue[int]", token: CancelToken) -> None: try: while not token.triggered: if reader.at_eof(): break try: size = queue.get_nowait() except asyncio.QueueEmpty: await asyncio.sleep(0) continue data = await token.cancellable_wait(reader.readexactly(size)) writer.write(data) queue.task_done() await token.cancellable_wait(writer.drain()) except OperationCancelled: pass finally: writer.write_eof() if reader.at_eof(): reader.feed_eof()
async def _client_accept(
    self,
    reader: StreamReader,
    writer: StreamWriter,
    read_ahead: bytes = None,
) -> None:
    """Accept a new client and relay its traffic through the tunnel.

    Connections whose IP ``self.block`` rejects are closed immediately.
    Otherwise a ``Connection`` with a fresh token is registered, and the tunnel
    receives a ``ClientInitPackage``.  Then ``read_ahead`` (if any) and every
    chunk read from the client are forwarded as tunnel data.

    When the client disconnects, or anything in the relay raises, a
    ``ClientClosePackage`` is sent (only while ``self.server`` is serving).
    The client is always disconnected; exceptions still propagate.
    """
    host, port = writer.get_extra_info("peername")[:2]
    ip = ipaddress.ip_address(host)

    # Block connections using the networks
    if self.block(ip):
        reader.feed_eof()
        writer.close()
        await writer.wait_closed()
        _logger.info("Connection from %s blocked", ip)
        return

    self.connections[ip].hits += 1

    # Create the client object and generate an unique token
    client = Connection(reader, writer, self.protocol, utils.generate_token())
    self.add(client)
    _logger.info("Client %s connected on %s:%s", client.uuid, host, port)

    try:
        # Inform the tunnel about the new client
        pkg = package.ClientInitPackage(ip, port, client.token)
        await self.tunnel.tun_write(pkg)

        # Send the buffer read ahead of initialization through the tunnel
        if read_ahead:
            await self.tunnel.tun_data(client.token, read_ahead)

        # Serve data from the client until it disconnects (empty read)
        while True:
            data = await client.read(self.chunk_size)
            if not data:
                break
            await self.tunnel.tun_data(client.token, data)
    finally:
        # Cleanup runs on normal disconnect and on errors, so a failing read
        # or tunnel write no longer leaves the client registered.
        try:
            if self.server and self.server.is_serving():
                pkg = package.ClientClosePackage(client.token)
                await self.tunnel.tun_write(pkg)
        finally:
            await self._disconnect_client(client.token)
async def receive_handshake(
    self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
    """Run the inbound handshake for a freshly accepted connection.

    Peers on the dial-in blacklist are refused: the streams are closed and the
    handshake is not attempted.  Failures of ``self._receive_handshake`` are
    logged and the streams closed.  The expected handshake errors,
    ``OperationCancelled`` and any other ``Exception`` each get their own log
    message.  Nothing is re-raised.
    """
    ip, port, *_ = writer.get_extra_info("peername")
    remote_address = Address(ip, port)
    if self.peer_pool.chk_dialin_blacklist(remote_address):
        Logger.info_every_n(
            "{} has been blacklisted, refusing connection".format(remote_address),
            100,
        )
        reader.feed_eof()
        writer.close()
        # Without this return the handshake would run on the closed stream.
        return

    expected_exceptions = (
        TimeoutError,
        PeerConnectionLost,
        HandshakeFailure,
        asyncio.IncompleteReadError,
        HandshakeDisconnectedFailure,
    )
    try:
        await self._receive_handshake(reader, writer)
    except expected_exceptions as e:
        self.logger.debug("Could not complete handshake: %s", e)
        Logger.error_every_n("Could not complete handshake: {}".format(e), 100)
        reader.feed_eof()
        writer.close()
    except OperationCancelled:
        self.logger.error("OperationCancelled")
        reader.feed_eof()
        writer.close()
    except Exception:
        self.logger.exception("Unexpected error handling handshake")
        reader.feed_eof()
        writer.close()
async def receive_handshake(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    """Run the inbound handshake, closing the streams if it does not complete.

    Any failure of ``self._receive_handshake`` closes both streams and is then
    classified.  ``COMMON_RECEIVE_HANDSHAKE_EXCEPTIONS`` are logged at debug
    level.  ``asyncio.CancelledError`` is re-raised.  ``OperationCancelled`` is
    ignored.  Any other ``Exception`` is logged with its traceback.
    """
    try:
        try:
            await self._receive_handshake(reader, writer)
        except BaseException:
            # BaseException rather than Exception: since Python 3.8
            # asyncio.CancelledError is not an Exception subclass, and a
            # cancelled handshake must still release the streams.
            if not reader.at_eof():
                reader.feed_eof()
            writer.close()
            raise
    except COMMON_RECEIVE_HANDSHAKE_EXCEPTIONS as e:
        self.logger.debug("Could not complete handshake: %s", e)
    except asyncio.CancelledError:
        # This exception should just bubble.
        raise
    except OperationCancelled:
        pass
    except Exception:
        self.logger.exception("Unexpected error handling handshake")
async def receive_handshake(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    """Perform the inbound handshake; on any failure release both streams.

    Whatever ``self._receive_handshake`` raises (cancellation included) first
    triggers stream cleanup, then is dispatched.  Members of
    ``COMMON_RECEIVE_HANDSHAKE_EXCEPTIONS`` are logged at debug level together
    with the peer name.  ``asyncio.CancelledError`` propagates.  Any other
    ``Exception`` is logged with a traceback.  Non-Exception BaseExceptions
    propagate after cleanup.
    """
    def shutdown_streams() -> None:
        if not reader.at_eof():
            reader.feed_eof()
        writer.close()

    try:
        try:
            await self._receive_handshake(reader, writer)
        except BaseException:
            shutdown_streams()
            raise
    except COMMON_RECEIVE_HANDSHAKE_EXCEPTIONS as err:
        self.logger.debug(
            "Could not complete handshake with %s: %s",
            writer.get_extra_info("peername"),
            err,
        )
    except asyncio.CancelledError:
        # Cancellation must reach the caller untouched.
        raise
    except Exception:
        self.logger.exception(
            "Unexpected error handling handshake with %s",
            writer.get_extra_info("peername"),
        )
async def tcp_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Read newline-terminated lines from a client and pass each to ``parse_line``.

    The connection is dropped when the client is idle for more than 5 seconds
    or resets the connection.  The writer is always closed on exit.
    Cancellation is no longer swallowed (it used to be logged as a timeout);
    it propagates after cleanup.
    """
    addr = writer.get_extra_info('peername')
    log.info("Client connected %r", addr)
    try:
        while not reader.at_eof():
            try:
                # asyncio.wait_for gives the same 5 s read deadline as the
                # previous async_timeout block, without the extra dependency.
                line = await asyncio.wait_for(reader.readline(), 5)
            except asyncio.TimeoutError:
                log.info('Client connection closed after timeout')
                break
            except ConnectionResetError:
                log.warning('Client connection reset')
                reader.feed_eof()
                break
            if line:
                parse_line(line)
    finally:
        writer.close()
        log.info("Client disconnected %r", addr)
class MQTTClientProtocol(FlowControlMixin, asyncio.Protocol):
    """Client side of an MQTT-framed relay.

    Local data passed to :meth:`write` is encrypted with ``self._encryptor`` and
    published (QoS 0) on the given topic.  PUBLISH packets from the server are
    decrypted and forwarded to the endpoint registered for their topic with
    :meth:`register_client_topic`.  A PUBLISH with the retain flag set closes
    that endpoint.  After the connection is lost, a reconnect is scheduled
    5 seconds later.
    """

    def __init__(self, loop, config):
        super().__init__(loop=loop)
        self._loop = loop
        self._config = config
        self._transport = None
        self._write_pending_data_topic = []  # tuple (data, topic)
        self._connected = False
        self._encryptor = cryptor.Cryptor(self._config['password'], self._config['method'])
        self._peername = None
        self._reader_task = None
        self._data_task = None
        self._keepalive_task = None
        self._keepalive_timeout = self._config['timeout']
        self._reader_ready = None
        self._reader_stopped = asyncio.Event(loop=self._loop)
        self._stream_reader = StreamReader(loop=self._loop)
        self._stream_writer = None
        self._reader = None
        self._topic_to_clients = {}
        self._queue = Queue(loop=loop)

    async def create_connection(self):
        """Connect to the configured server; on OSError retry after 5 seconds."""
        try:
            # TODO handle pending task
            await self._loop.create_connection(
                lambda: self, self._config['address'], self._config['port'])
        except OSError as e:
            logging.error("{0} when connecting to mqtt server({1}:{2})".format(
                e, self._config['address'], self._config['port']))
            logging.error("Reconnection will be performed after 5s...")
            await asyncio.sleep(5)  # TODO:retry interval
            self._loop.create_task(self.create_connection())

    def connection_made(self, transport):
        self._peername = transport.get_extra_info('peername')
        self._transport = transport
        self._stream_reader.set_transport(transport)
        self._reader = StreamReaderAdapter(self._stream_reader)
        self._stream_writer = StreamWriter(transport, self, self._stream_reader, self._loop)
        self._loop.create_task(self.start())

    def connection_lost(self, exc):
        logging.info("Lost connection with mqtt server{0}".format(self._peername))
        super().connection_lost(exc)
        self._topic_to_clients = {}
        if self._stream_reader is not None:
            if exc is None:
                self._stream_reader.feed_eof()
            else:
                self._stream_reader.set_exception(exc)
        # stop() is a coroutine: calling it bare (as before) only created a
        # coroutine object and never ran it, so it has to be scheduled.
        self._loop.create_task(self.stop())
        self.reestablish_connection()

    def reestablish_connection(self):
        """Reset per-connection state and schedule a reconnect in 5 seconds."""
        self._stream_reader = StreamReader(loop=self._loop)
        self._encryptor = cryptor.Cryptor(self._config['password'], self._config['method'])
        self._loop.call_later(5, lambda: self._loop.create_task(self.create_connection()))

    def data_received(self, data):
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()

    @asyncio.coroutine
    def consume(self):
        """Send queued packets until a ``None`` sentinel or the transport goes away."""
        while self._transport is not None:
            packet = yield from self._queue.get()
            if packet is None:
                break
            if self._transport is None:
                break
            yield from self._send_packet(packet)

    @asyncio.coroutine
    def start(self):
        """Start the reader, keep-alive timer and sender, then send CONNECT."""
        # The event outlives a single connection; clear it so stop() on a
        # reconnected session cancels the new reader instead of assuming it
        # already stopped.
        self._reader_stopped.clear()
        self._reader_ready = asyncio.Event(loop=self._loop)
        self._reader_task = asyncio.Task(self._reader_loop(), loop=self._loop)
        yield from self._reader_ready.wait()
        if self._keepalive_timeout:
            self._keepalive_task = self._loop.call_later(
                self._keepalive_timeout, self.handle_write_timeout)
        self._data_task = self._loop.create_task(self.consume())

        # send connect packet
        connect_vh = ConnectVariableHeader(keep_alive=self._keepalive_timeout)
        connect_vh.password_flag = True
        password = self._encryptor.encrypt(self._encryptor.password.encode('utf-8'))
        connect_payload = ConnectPayload(client_id=ConnectPayload.gen_client_id(), password=password)
        connect_packet = ConnectPacket(vh=connect_vh, payload=connect_payload)
        yield from self._do_write(connect_packet)
        logging.info("Creating connection to mqtt server.")

    @asyncio.coroutine
    def stop(self):
        """Cancel timers and tasks; close the transport once the reader has stopped."""
        self._connected = False
        if self._keepalive_task:
            self._keepalive_task.cancel()
        if self._data_task is not None:
            self._data_task.cancel()
        logging.debug("waiting for tasks to be stopped")
        if self._reader_task is not None and not self._reader_task.done():
            if not self._reader_stopped.is_set():
                # this will cause the reader_loop handle CancelledError
                self._reader_task.cancel()
            else:
                # caused by reader_loop break statement
                if self._transport:
                    self._transport.close()
                    self._transport = None

    @asyncio.coroutine
    def _reader_loop(self):
        """Read packets until EOF, read timeout, a reserved packet or an error.

        CONNACK, PINGREQ, PINGRESP and DISCONNECT are handled in separate
        tasks.  PUBLISH is handled inline, presumably to keep relayed data in
        order.  On exit, pending handler tasks are cancelled and stop() runs.
        """
        running_tasks = collections.deque()
        while True:
            try:
                self._reader_ready.set()
                while running_tasks and running_tasks[0].done():
                    running_tasks.popleft()
                if len(running_tasks) > 1:
                    logging.debug("{} Handler running tasks: {}".format(self._peername, len(running_tasks)))

                fixed_header = yield from asyncio.wait_for(
                    MQTTFixedHeader.from_stream(self._reader),
                    self._keepalive_timeout + 10, loop=self._loop)
                if not fixed_header:
                    logging.debug("{} No more data (EOF received), stopping reader coro".format(self._peername))
                    break
                if fixed_header.packet_type == RESERVED_0 or fixed_header.packet_type == RESERVED_15:
                    logging.warning("{} Received reserved packet, which is forbidden: closing connection".format(self._peername))
                    break

                cls = packet_class(fixed_header)
                packet = yield from cls.from_stream(self._reader, fixed_header=fixed_header)
                packet_type = packet.fixed_header.packet_type
                task = None
                if packet_type == CONNACK:
                    task = ensure_future(self.handle_connack(packet), loop=self._loop)
                elif packet_type == PINGREQ:
                    task = ensure_future(self.handle_pingreq(packet), loop=self._loop)
                elif packet_type == PINGRESP:
                    task = ensure_future(self.handle_pingresp(packet), loop=self._loop)
                elif packet_type == PUBLISH:
                    self.handle_publish(packet)
                elif packet_type == DISCONNECT:
                    # NOTE(review): handle_disconnect is not defined in this
                    # class; unless a base class provides it this raises and
                    # ends the reader loop — confirm.
                    task = ensure_future(self.handle_disconnect(packet), loop=self._loop)
                else:
                    # SUBSCRIBE/UNSUBSCRIBE/SUBACK/UNSUBACK are not handled.
                    logging.warning("{} Unhandled packet type: {}".format(self._peername, packet.fixed_header.packet_type))
                if task:
                    running_tasks.append(task)
            except MQTTException:
                logging.debug("{} Message discarded".format(self._peername))
            except asyncio.CancelledError:
                break
            except asyncio.TimeoutError:
                logging.debug("{} Input stream read timeout".format(self._peername))
                break
            except NoDataException:
                logging.debug("{} No data available".format(self._peername))
            except BaseException as e:
                logging.warning(
                    "{}:{} Unhandled exception in reader coro: {}".format(type(self).__name__, self._peername, e))
                break
        while running_tasks:
            running_tasks.popleft().cancel()
        self._reader_stopped.set()
        logging.debug("{} Reader coro stopped".format(self._peername))
        yield from self.stop()

    def write(self, data: bytes, topic):
        """Publish *data* on *topic*, buffering it until CONNACK when not connected.

        The buffer holds at most 50 entries; when it overflows, the whole
        backlog is discarded.
        """
        if not self._connected:
            self._write_pending_data_topic.append((data, topic))
            if len(self._write_pending_data_topic) > 50:
                self._write_pending_data_topic.clear()
        else:
            data = self._encryptor.encrypt(data)
            packet = PublishPacket.build(topic, data, None, dup_flag=0, qos=0, retain=0)
            ensure_future(self._do_write(packet), loop=self._loop)

    def write_eof(self, topic):
        """Signal end-of-stream for *topic* with an empty PUBLISH carrying retain=1."""
        packet = PublishPacket.build(topic, b'', None, dup_flag=0, qos=0, retain=1)
        ensure_future(self._do_write(packet), loop=self._loop)

    @asyncio.coroutine
    def _do_write(self, packet):
        yield from self._queue.put(packet)

    def _reschedule_keepalive(self):
        """(Re)arm the PINGREQ timer; a no-op when keep-alive is disabled."""
        if self._keepalive_task is not None:
            self._keepalive_task.cancel()
            self._keepalive_task = None
        if self._keepalive_timeout:
            self._keepalive_task = self._loop.call_later(self._keepalive_timeout, self.handle_write_timeout)

    @asyncio.coroutine
    def _send_packet(self, packet):
        try:
            yield from packet.to_stream(self._stream_writer)
        except ConnectionResetError:
            return
        self._reschedule_keepalive()

    def handle_write_timeout(self):
        """Send a PINGREQ after keep-alive seconds without outgoing packets."""
        if self._transport is None:
            # Connection already torn down: nothing to ping, do not re-arm.
            return
        packet = PingReqPacket()
        self._transport.write(packet.to_bytes())
        self._reschedule_keepalive()

    def handle_read_timeout(self):
        self._loop.create_task(self.stop())

    @asyncio.coroutine
    def handle_connack(self, connack: ConnackPacket):
        """Mark the session connected and flush buffered writes, or stop on refusal."""
        if connack.variable_header.return_code == 0:
            self._connected = True
            logging.info("Connection to mqtt server established!")
            if self._write_pending_data_topic:
                if self._keepalive_task:
                    self._keepalive_task.cancel()
                pending, self._write_pending_data_topic = self._write_pending_data_topic, []
                for data, topic in pending:
                    data = self._encryptor.encrypt(data)
                    packet = PublishPacket.build(topic, data, None, dup_flag=0, qos=0, retain=0)
                    yield from self._do_write(packet)
                self._reschedule_keepalive()
        else:
            logging.info("Unable to create connection to mqtt server! Shuting down...")
            self._loop.create_task(self.stop())

    def handle_publish(self, publish_packet: PublishPacket):
        """Forward decrypted payload to the topic's endpoint; retain=1 closes it."""
        data = bytes(publish_packet.data)
        server = self._topic_to_clients.get(publish_packet.topic_name, None)
        if server is None:
            logging.info("Received unregistered publish topic({0}) from mqtt server, packet will be ignored.".format(
                publish_packet.topic_name))
        if not publish_packet.retain_flag:
            # Decrypt even for unregistered topics, as before, in case the
            # cryptor keeps state between calls.
            data = self._encryptor.decrypt(data)
            if server is not None:
                server.write(data)
        else:
            # retain=1 indicate we should close the client connection
            if server is not None:
                server.close(force=True)

    @asyncio.coroutine
    def handle_pingresp(self, pingresp: PingRespPacket):
        logging.info("Received PingRespPacket from mqtt server.")

    @asyncio.coroutine
    def handle_pingreq(self, pingreq: PingReqPacket):
        logging.info("Received PingReqPacket from mqtt server, Replying PingResqPacket.")
        ping_resp = PingRespPacket()
        yield from self._do_write(ping_resp)

    def register_client_topic(self, topic, server):
        self._topic_to_clients[topic] = server

    def unregister_client_topic(self, topic):
        self._topic_to_clients.pop(topic, None)
class MQTTServerProtocol(FlowControlMixin, asyncio.Protocol):
    """Server side of an MQTT-framed relay.

    Clients authenticate with an encrypted password in CONNECT.  Per topic, the
    first decrypted PUBLISH carries an address header (``common.parse_header``)
    that opens a ``RelayRemoteProtocol`` connection.  Later PUBLISH payloads are
    forwarded to that remote.  A PUBLISH with retain=1 closes the remote.
    Data coming back from remotes is published via :meth:`write`.
    """

    def __init__(self, loop, config):
        super().__init__(loop=loop)
        self._loop = loop
        self._transport = None
        self._encryptor = cryptor.Cryptor(config['password'], config['method'])
        self._topic_to_remote = {}
        self._peername = None
        self._reader_task = None
        self._data_task = None
        self._keepalive_task = None
        self._keepalive_timeout = config['timeout']
        self._reader_ready = None
        self._reader_stopped = asyncio.Event(loop=self._loop)
        self._stream_reader = StreamReader(loop=self._loop)
        self._stream_writer = None
        self._reader = None
        self._approved = False
        self._queue = Queue(loop=loop)

    def connection_made(self, transport):
        self._peername = transport.get_extra_info('peername')
        self._transport = transport
        logging.info("Mqtt client connected from: {}.".format(self._peername))
        self._stream_reader.set_transport(transport)
        self._reader = StreamReaderAdapter(self._stream_reader)
        self._stream_writer = StreamWriter(transport, self, self._stream_reader, self._loop)
        self._loop.create_task(self.start())

    def connection_lost(self, exc):
        logging.info("Mqtt client connection{} lost.".format(self._peername))
        super().connection_lost(exc)
        if self._stream_reader is not None:
            if exc is None:
                self._stream_reader.feed_eof()
            else:
                self._stream_reader.set_exception(exc)
        # stop() is a coroutine: calling it bare (as before) never ran it.
        self._loop.create_task(self.stop())

    def data_received(self, data):
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()

    @asyncio.coroutine
    def consume(self):
        """Send queued packets until a ``None`` sentinel or the transport goes away."""
        while self._transport is not None:
            packet = yield from self._queue.get()
            if packet is None:
                break
            if self._transport is None:
                break
            yield from self._send_packet(packet)

    @asyncio.coroutine
    def start(self):
        """Start the reader, the keep-alive timer and the sender task."""
        self._reader_ready = asyncio.Event(loop=self._loop)
        self._reader_task = asyncio.Task(self._reader_loop(), loop=self._loop)
        yield from self._reader_ready.wait()
        if self._keepalive_timeout:
            self._keepalive_task = self._loop.call_later(
                self._keepalive_timeout, self.handle_write_timeout)
        self._data_task = self._loop.create_task(self.consume())

    @asyncio.coroutine
    def stop(self):
        """Cancel timers and tasks, close the transport once the reader stopped, close remotes."""
        if self._keepalive_task:
            self._keepalive_task.cancel()
        if self._data_task is not None:
            self._data_task.cancel()
        logging.debug("waiting for tasks to be stopped")
        if self._reader_task is not None and not self._reader_task.done():
            if not self._reader_stopped.is_set():
                # this will cause the reader_loop handle CancelledError
                self._reader_task.cancel()
            else:
                # caused by reader_loop break statement
                if self._transport:
                    self._transport.close()
                    self._transport = None
        # Iterate over a snapshot in case closing a remote leads to
        # remove_topic() mutating the mapping.
        for remote in list(self._topic_to_remote.values()):
            remote.close()

    @asyncio.coroutine
    def _reader_loop(self):
        """Read packets until EOF, read timeout, a reserved packet or an error.

        CONNECT, PINGREQ, PINGRESP and DISCONNECT are handled in separate
        tasks.  PUBLISH is handled inline.  On exit, pending handler tasks are
        cancelled and stop() runs.
        """
        running_tasks = collections.deque()
        while True:
            try:
                self._reader_ready.set()
                while running_tasks and running_tasks[0].done():
                    running_tasks.popleft()
                if len(running_tasks) > 1:
                    logging.debug("{} Handler running tasks: {}".format(
                        self._peername, len(running_tasks)))

                fixed_header = yield from asyncio.wait_for(
                    MQTTFixedHeader.from_stream(self._reader),
                    self._keepalive_timeout + 10, loop=self._loop)
                if not fixed_header:
                    logging.debug(
                        "{} No more data (EOF received), stopping reader coro".
                        format(self._peername))
                    break
                if fixed_header.packet_type == RESERVED_0 or fixed_header.packet_type == RESERVED_15:
                    logging.warning(
                        "{} Received reserved packet, which is forbidden: closing connection"
                        .format(self._peername))
                    break

                cls = packet_class(fixed_header)
                packet = yield from cls.from_stream(
                    self._reader, fixed_header=fixed_header)
                packet_type = packet.fixed_header.packet_type
                task = None
                if packet_type == CONNECT:
                    task = ensure_future(self.handle_connect(packet), loop=self._loop)
                elif packet_type == PINGREQ:
                    task = ensure_future(self.handle_pingreq(packet), loop=self._loop)
                elif packet_type == PINGRESP:
                    task = ensure_future(self.handle_pingresp(packet), loop=self._loop)
                elif packet_type == PUBLISH:
                    self.handle_publish(packet)
                elif packet_type == DISCONNECT:
                    # NOTE(review): handle_disconnect is not defined in this
                    # class; confirm a base class provides it.
                    task = ensure_future(
                        self.handle_disconnect(packet), loop=self._loop)
                else:
                    # TODO: handle unknow packet type
                    logging.warning(
                        "{} Unhandled packet type: {}".format(
                            self._peername, packet.fixed_header.packet_type))
                if task:
                    running_tasks.append(task)
            except MQTTException:
                logging.debug("{} Message discarded".format(self._peername))
            except asyncio.CancelledError:
                break
            except asyncio.TimeoutError:
                logging.debug("{} Input stream read timeout".format(
                    self._peername))
                break
            except NoDataException:
                logging.debug("{} No data available".format(self._peername))
            except BaseException as e:
                logging.warning(
                    "{}:{} Unhandled exception in reader coro: {}".format(
                        type(self).__name__, self._peername, e))
                break
        while running_tasks:
            running_tasks.popleft().cancel()
        self._reader_stopped.set()
        logging.debug("{} Reader coro stopped".format(self._peername))
        yield from self.stop()

    # for remote read
    def write(self, data, client_topic):
        """Encrypt *data* from a remote and publish it on *client_topic*."""
        data = self._encryptor.encrypt(data)
        packet = PublishPacket.build(client_topic, data, None, dup_flag=0, qos=0, retain=0)
        ensure_future(self._do_write(packet), loop=self._loop)

    def _write_eof(self, client_topic):
        """Signal end-of-stream for *client_topic* with an empty retain=1 PUBLISH."""
        packet = PublishPacket.build(client_topic, b'', None, dup_flag=0, qos=0, retain=1)
        ensure_future(self._do_write(packet), loop=self._loop)

    @asyncio.coroutine
    def _do_write(self, packet):
        yield from self._queue.put(packet)

    def _reschedule_keepalive(self):
        """(Re)arm the PINGREQ timer; a no-op when keep-alive is disabled."""
        if self._keepalive_task is not None:
            self._keepalive_task.cancel()
            self._keepalive_task = None
        if self._keepalive_timeout:
            self._keepalive_task = self._loop.call_later(self._keepalive_timeout, self.handle_write_timeout)

    @asyncio.coroutine
    def _send_packet(self, packet):
        # Same reset handling as the client protocol, so a peer reset does not
        # kill the sender task with an unhandled exception.
        try:
            yield from packet.to_stream(self._stream_writer)
        except ConnectionResetError:
            return
        self._reschedule_keepalive()

    def handle_write_timeout(self):
        """Send a PINGREQ after keep-alive seconds without outgoing packets."""
        if self._transport is None:
            return
        packet = PingReqPacket()
        self._transport.write(packet.to_bytes())
        self._reschedule_keepalive()

    def handle_read_timeout(self):
        self._loop.create_task(self.stop())

    @asyncio.coroutine
    def handle_connect(self, connect: ConnectPacket):
        """Check the CONNECT password, reply with CONNACK and stop if refused (code 4)."""
        return_code = 0
        self._approved = True
        # The password comes from an untrusted peer: a missing or undecodable
        # value is treated as a wrong password instead of crashing the task
        # before any CONNACK is sent.
        password = None
        if connect.password is not None:
            try:
                password = self._encryptor.decrypt(connect.password).decode('utf-8')
            except UnicodeDecodeError:
                password = None
        if password != self._encryptor.password:
            return_code = 4
            self._approved = False
            logging.warning(
                "Invalid ConnectPacket password from mqtt client connection{}!"
                .format(self._peername))
        connack_vh = ConnackVariableHeader(return_code=return_code)
        connack = ConnackPacket(variable_header=connack_vh)
        yield from self._do_write(connack)
        if return_code != 0:
            self._loop.create_task(self.stop())

    def handle_publish(self, publish_packet: PublishPacket):
        """Open, feed or close the remote connection associated with the topic."""
        if not self._approved:
            self._loop.create_task(self.stop())
            return
        data = bytes(publish_packet.data)
        remote = self._topic_to_remote.get(publish_packet.topic_name, None)
        if not publish_packet.retain_flag:
            data = self._encryptor.decrypt(data)
            if remote is None:  # we are in STAGE_ADDR
                if not data:
                    self._write_eof(publish_packet.topic_name)
                    return
                header_result = common.parse_header(data)
                if header_result is None:
                    logging.error(
                        "Can not parse header when handling mqtt client({}) connection{}."
                        .format(publish_packet.topic_name, self._peername))
                    self._write_eof(publish_packet.topic_name)
                    return
                addrtype, remote_addr, remote_port, header_length = header_result
                logging.info(
                    "Connecting to remote {}:{} from mqtt client({}) connection{}."
                    .format(common.to_str(remote_addr), remote_port,
                            publish_packet.topic_name, self._peername))
                remote = RelayRemoteProtocol(self._loop, self, publish_packet.topic_name)
                self._topic_to_remote[publish_packet.topic_name] = remote
                self._loop.create_task(
                    self.create_connection(remote, common.to_str(remote_addr), remote_port))
                # Bytes after the address header are the first payload chunk.
                if len(data) > header_length:
                    remote.write(data[header_length:])
            else:  # now in STAGE_STREAM
                remote.write(data)
        else:
            if remote is not None:
                remote.close(force=True)

    @asyncio.coroutine
    def handle_pingresp(self, pingresp: PingRespPacket):
        logging.info("Received PingRespPacket from mqtt client.")

    @asyncio.coroutine
    def handle_pingreq(self, pingreq: PingReqPacket):
        logging.info(
            "Received PingRepPacket from mqtt client, replying PingRespPacket."
        )
        ping_resp = PingRespPacket()
        yield from self._do_write(ping_resp)

    async def create_connection(self, remote, host, port):
        """Connect *remote* to host:port; on OSError log and drop the topic."""
        try:
            # TODO handle pending task
            await self._loop.create_connection(lambda: remote, host, port)
        except OSError as e:
            logging.error(
                "{} when creating remote connection to {}:{} from mqtt connection{}."
                .format(e, host, port, self._peername))
            self.remove_topic(remote.client_topic)

    def remove_topic(self, topic):
        """Forget *topic*, telling the client end-of-stream while connected."""
        if self._transport is not None:
            self._write_eof(topic)
        self._topic_to_remote.pop(topic, None)
async def close(self, reader: StreamReader, writer: StreamWriter) -> None:
    """Shut down a reader/writer pair.

    EOF is fed to *reader*, then *writer* is closed.  The coroutine returns
    once ``writer.wait_closed()`` completes.
    """
    mark_eof = reader.feed_eof
    shut_writer = writer.close
    mark_eof()
    shut_writer()
    await writer.wait_closed()
def disconnect(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    """Close a stream pair, feeding EOF to *reader* only if it is not finished yet."""
    reader_done = reader.at_eof()
    if not reader_done:
        reader.feed_eof()
    writer.close()
class CommandProtocol(asyncio.Protocol):
    """asyncio protocol that parses RESP-style commands from a TCP connection.

    Each command is either a RESP value (type prefix ``:``, ``$``, ``*`` or
    ``,``) or an inline space-separated line terminated by CRLF.  Parsed commands
    are passed to :meth:`on_command_received`, which subclasses implement along
    with :meth:`reply_terminating` and :meth:`receive_commands_end`.

    TCP keep-alive notes (https://blog.cloudflare.com/when-tcp-sockets-refuse-to-die/,
    http://coryklein.com/tcp/2015/11/25/custom-configuration-of-tcp-socket-keep-alive-timeouts.html):

    Three properties control keep-alives.  On Linux they are:
    tcp_keepalive_time (default 7200 seconds), tcp_keepalive_probes (default 9)
    and tcp_keepalive_intvl (default 75 seconds).

    1. The client opens a TCP connection.
    2. If the connection is silent for tcp_keepalive_time seconds, a single
       empty ACK packet is sent.
    3. If the peer answers with its own ACK, go back to step 2.  Otherwise,
       wait tcp_keepalive_intvl seconds and send another ACK, up to
       tcp_keepalive_probes probes.  If there is still no response, send a RST
       and terminate the connection.
    """

    # Checked by every handler; when true, input is ignored.  Nothing in this
    # class sets it — presumably a subclass or owner does.
    kill_active = False

    def __init__(self, tcp_keepalive_time=None, tcp_keepalive_interval=None,
                 tcp_keepalive_probes=None, tcp_user_timeout_seconds=None,
                 *args, **kwargs):
        self.transport = None
        self._reader = None
        # Strong reference to the receive_commands() task (see connection_made).
        self._receive_task = None
        # None means "not configured"; previously int(None) made the declared
        # defaults unusable.
        self.tcp_keepalive_time = self._optional_int(tcp_keepalive_time)
        self.tcp_keepalive_interval = self._optional_int(tcp_keepalive_interval)
        self.tcp_keepalive_probes = self._optional_int(tcp_keepalive_probes)
        self.tcp_user_timeout_seconds = self._optional_int(tcp_user_timeout_seconds)
        super().__init__(*args, **kwargs)

    @staticmethod
    def _optional_int(value):
        """Return ``int(value)``, or ``None`` when *value* is ``None``."""
        return None if value is None else int(value)

    def data_received(self, data):
        if self.kill_active:
            return
        self._reader.feed_data(data)

    def _configure_socket(self, sock):
        """Apply keep-alive, TCP_NODELAY and (if available) TCP_USER_TIMEOUT."""
        keepalive_opts = {'tcp_keepalive': True}
        for key, value in (('tcp_keepalive_idle', self.tcp_keepalive_time),
                           ('tcp_keepalive_intvl', self.tcp_keepalive_interval),
                           ('tcp_keepalive_cnt', self.tcp_keepalive_probes)):
            # NOTE(review): assumes set_tcp_keepalive leaves omitted options at
            # their OS defaults — confirm.  When all values are given the dict
            # is identical to before.
            if value is not None:
                keepalive_opts[key] = value
        set_tcp_keepalive(sock, opts=keepalive_opts)
        # https://eklitzke.org/the-caveats-of-tcp-nodelay
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if hasattr(socket, 'TCP_USER_TIMEOUT') and self.tcp_user_timeout_seconds is not None:
            logger.debug('Setting TCP_USER_TIMEOUT to %s', self.tcp_user_timeout_seconds * 1000)
            sock.setsockopt(socket.SOL_TCP, socket.TCP_USER_TIMEOUT, self.tcp_user_timeout_seconds * 1000)

    def connection_made(self, transport):
        if self.kill_active:
            return
        sock = transport.get_extra_info('socket')
        if sock is not None:
            self._configure_socket(sock)
        self.transport = transport
        self._reader = StreamReader()
        self._reader.set_transport(transport)
        loop = asyncio.get_event_loop()
        # Keep a reference: the loop holds tasks only weakly, so an
        # unreferenced task can be garbage-collected mid-execution.
        self._receive_task = loop.create_task(self.receive_commands())
        super().connection_made(transport)

    def connection_lost(self, exc):
        if self.kill_active:
            return
        if self._reader is not None:
            if exc is None:
                self._reader.feed_eof()
            else:
                self._reader.set_exception(exc)
        super().connection_lost(exc)

    def eof_received(self):
        if self.kill_active:
            return
        self._reader.feed_eof()
        return super().eof_received()

    async def _read_int(self):
        if self.kill_active:
            return
        line = await self._reader.readuntil(b'\r\n')
        return int(line.decode().strip())

    async def _read_float(self):
        if self.kill_active:
            return
        line = await self._reader.readuntil(b'\r\n')
        return float(line.decode().strip())

    async def _read_bytes(self):
        """Read a bulk string body; ``None`` for the RESP null bulk string."""
        if self.kill_active:
            return
        length = await self._read_int()
        if length is None or length < 0:
            # "$-1\r\n" has no payload and no trailing CRLF.  The old
            # max(2, length + 2) read swallowed two bytes of the next message.
            return None
        line = await self._reader.readexactly(length + 2)
        if line[-1] != ord(b'\n'):
            raise Exception(r"line[-1] != ord(b'\n')")
        if line[-2] != ord(b'\r'):
            raise Exception(r"line[-2] != ord(b'\r')")
        return line[:-2]

    async def _read_array(self):
        """Read an array of RESP values; ``None`` for the RESP null array."""
        if self.kill_active:
            return
        count = await self._read_int()
        if count is None or count < 0:
            # "*-1\r\n" previously made "while len:" loop forever.
            return None
        items = []
        for _ in range(count):
            c = await self._reader.readexactly(1)
            items.append(await self._receive_resp(c))
        return items

    async def _receive_resp(self, c):
        """Parse one RESP value whose type byte *c* has already been read."""
        if self.kill_active:
            return
        if c == b':':
            return await self._read_int()
        elif c == b'$':
            return await self._read_bytes()
        elif c == b'*':
            return await self._read_array()
        elif c == b',':
            return await self._read_float()
        else:
            raise Exception('Unknown RESP start char %s' % c)

    async def receive_commands(self):
        """Read commands until the connection closes or kill_active is set."""
        while True:
            if self.kill_active:
                self.reply_terminating()
                return
            try:
                c = await self._reader.readexactly(1)
                if c in b':*$,':
                    value = await self._receive_resp(c)
                    if not isinstance(value, list):
                        value = [value]
                else:
                    # Inline command: the rest of the CRLF-terminated line,
                    # split on single spaces.
                    command = c + await self._reader.readuntil(b'\r\n')
                    value = [x.strip().encode() for x in command.decode().split(' ')]
            except (ConnectionAbortedError, ConnectionResetError, IncompleteReadError, TimeoutError) as e:
                # Connection is closed
                self.receive_commands_end(e)
                return
            await self.on_command_received(*value)

    async def on_command_received(self, command):
        raise NotImplementedError()

    def reply_terminating(self):
        raise NotImplementedError()

    def receive_commands_end(self, exc):
        raise NotImplementedError()
class ScriptedModem(SerialDeviceFactory, Protocol, AsyncioService):
    """
    Fake serial modem that replays a pre-programmed script.

    Scripted actions are consumed in the order they were registered.  Besides
    reacting to expected input, timed actions are supported (for example:
    after 3 seconds, emit this line).  In command mode the modem also answers
    ``ATECHO <payload>`` by echoing the payload to its output.
    """
    name = 'fake modem'

    def __init__(self, aio_loop_service: AsyncioEventLoop, command_mode=False, defer_script=False):
        AsyncioService.__init__(self, aio_loop_service)
        # Behaviour flags.
        self.command_mode = command_mode
        self.defer_script = defer_script
        # Script queue and I/O buffers; created by connect().
        self.script = None
        self.in_buffer = None
        self.out_buffer = None
        # Actions registered before connect() created the queue.
        self._deferred_actions = []
        # Set by run_scripted_actions() when defer_script is enabled.
        self._defer_event = None

    # --------------------- Scripting methods ---------------------------------

    def on_input(self, input: str) -> ReplyAction:
        """
        Register a reaction to an expected input line.

        :param input: the input the modem expects to receive.
        :return: the ReplyAction, so a reply can be chained onto it.
        """
        self._allow_states(*ServiceState.halted_states())
        action = ReplyAction(input)
        self._add_action(action)
        return action

    def after(self, seconds: int):
        """Register a TimedAction that fires after *seconds* and return it."""
        self._allow_states(*ServiceState.halted_states())
        action = TimedAction(seconds)
        self._add_action(action)
        return action

    def load_script(self, lines: str, step: int = 0):
        """Queue one timed output action per line of *lines*; return the last one."""
        self._allow_states(*ServiceState.halted_states())
        final_action = None
        for text in lines.splitlines(keepends=False):
            final_action = self.after(step).output(text)
        return final_action

    def run_scripted_actions(self):
        """
        Release the scripted actions (or command mode) when startup was deferred.

        :return:
        """
        gate = self._defer_event
        if gate:
            self.aio_loop.call_soon_threadsafe(gate.set)

    def _add_action(self, action):
        queued = self._add_action_now(action)
        if not queued:
            self._deferred_actions.append(action)

    def _add_action_now(self, action):
        if self.script is None:
            return False
        action.set_loop(self.aio_loop)
        self.script.put_nowait(action)
        return True

    # ------------------- SerialDeviceFactory --------------------------------

    async def connect(self, aio_loop):
        self._allow_states(*ServiceState.halted_states())
        self.script = Queue(loop=aio_loop)
        self.in_buffer = StreamReader(loop=aio_loop)
        self.out_buffer = StreamReader(loop=aio_loop)

        # Move actions registered before the queue existed into it.
        for pending_action in self._deferred_actions:
            self._add_action_now(pending_action)
        self._deferred_actions = []

        # Optional start gate, released by run_scripted_actions() (e.g. so the
        # caller can register an EventStream first).
        if self.defer_script:
            self._defer_event = Event(loop=aio_loop)

        # sync_start() is synchronous but we are inside an asyncio task: run it
        # on a helper thread and wait for it to signal completion.
        started = Event(loop=aio_loop)

        def start_in_thread():
            self.sync_start()
            self.aio_loop.call_soon_threadsafe(started.set)

        Thread(target=start_in_thread).start()
        await started.wait()

        return self.out_buffer, self

    # ------------------ Fake StreamWriter ------------------------------------

    def write(self, data: bytes):
        self._allow_states(ServiceState.READY)
        consumed = self.command_mode and self._try_command(data)
        if not consumed:
            self.in_buffer.feed_data(data)

    def _try_command(self, data: bytes):
        encoding = CX930xx_fake.encoding
        command, *rest = data.decode(encoding).split(' ', maxsplit=1)
        payload = rest[0].encode(encoding) if rest else None

        # Echoes to output.
        if command == 'ATECHO' and payload:
            self.out_buffer.feed_data(payload)
            return True

        return False

    async def drain(self):
        # Simulates a slow write.
        await asyncio.sleep(0.1)

    @property
    def transport(self):
        return self

    def close(self):
        self.stop()

    # ------------------- Management methods ----------------------------------

    async def _event_loop(self):
        self._signal_started()

        # Hold off until run_scripted_actions() if startup was deferred.
        if self._defer_event:
            await self._defer_event.wait()

        # Run until the script is exhausted, or forever in command mode.
        while self.command_mode or not self.script.empty():
            action = await self.script.get()
            await action.process(self)

    def _graceful_cleanup(self):
        self.out_buffer.feed_eof()
        self.in_buffer.feed_eof()

    # ------------------- Convenience methods ---------------------------------

    @staticmethod
    def from_modem_type(modem_type: ModemType, aio_loop_service: AsyncioEventLoop) -> 'ScriptedModem':
        """Build a command-mode modem that answers 'OK' to each INIT command of *modem_type*."""
        modem = ScriptedModem(aio_loop_service=aio_loop_service, command_mode=True)
        for init_command in modem_type.commands[ModemType.INIT]:
            modem.on_input(init_command).reply('OK')
        return modem