コード例 #1
0
    async def handler(reader: asyncio.StreamReader,
                      writer: asyncio.StreamWriter):

        nonlocal data

        header_size = struct.calcsize('!L')

        while not reader.at_eof():
            try:
                header = await reader.readexactly(header_size)
            except asyncio.IncompleteReadError:
                break

            payload_size = struct.unpack("!L", header)[0]

            try:
                payload = await reader.readexactly(payload_size)
            except asyncio.IncompleteReadError:
                break

            for metric in pickle.loads(payload):
                data.append(metric)

        if len(data) == count:
            event.set()
        writer.close()
        reader.feed_eof()
コード例 #2
0
async def _connect_streams(reader: asyncio.StreamReader,
                           writer: asyncio.StreamWriter,
                           queue: "asyncio.Queue[int]",
                           token: CancelToken) -> None:
    """Copy bytes from ``reader`` to ``writer`` in sizes dictated by ``queue``.

    Every integer taken from ``queue`` is the exact number of bytes to read
    from ``reader`` and forward to ``writer``. Copying stops once ``token``
    is triggered, ``reader`` is at EOF, or a token-guarded wait raises
    ``OperationCancelled`` (which is swallowed). ``writer.write_eof()`` is
    called on every exit path.
    """
    try:
        while not (token.triggered or reader.at_eof()):
            try:
                chunk_size = queue.get_nowait()
            except asyncio.QueueEmpty:
                # Nothing scheduled yet: yield to the event loop, then re-check.
                # NOTE(review): this polls in a tight loop while the queue is
                # empty instead of blocking on queue.get().
                await asyncio.sleep(0)
                continue

            chunk = await token.cancellable_wait(reader.readexactly(chunk_size))
            writer.write(chunk)
            queue.task_done()
            await token.cancellable_wait(writer.drain())
    except OperationCancelled:
        pass
    finally:
        writer.write_eof()

    # NOTE(review): at_eof() is only true once EOF has already been fed, so
    # this call looks redundant -- confirm the intended condition.
    if reader.at_eof():
        reader.feed_eof()
コード例 #3
0
    async def _client_accept(
        self,
        reader: StreamReader,
        writer: StreamWriter,
        read_ahead: bytes = None,
    ) -> None:
        """Accept a new client and relay its data through the tunnel.

        Connections from blocked addresses are closed immediately. Otherwise
        a ``Connection`` with a freshly generated token is registered, the
        tunnel is informed about it, and the client's data is forwarded in
        ``self.chunk_size`` chunks until the client disconnects. The client
        is always disconnected afterwards, including when a tunnel write or
        client read raises or the task is cancelled.

        Args:
            reader: Stream carrying data from the client.
            writer: Stream carrying data to the client.
            read_ahead: Bytes already read from the client before this call;
                forwarded through the tunnel first if non-empty.
        """
        host, port = writer.get_extra_info("peername")[:2]
        ip = ipaddress.ip_address(host)

        # Block connections using the networks
        if self.block(ip):
            reader.feed_eof()
            writer.close()
            await writer.wait_closed()

            _logger.info("Connection from %s blocked", ip)
            return

        self.connections[ip].hits += 1

        # Create the client object and generate an unique token
        client = Connection(reader, writer, self.protocol,
                            utils.generate_token())
        self.add(client)

        _logger.info("Client %s connected on %s:%s", client.uuid, host, port)

        try:
            # Inform the tunnel about the new client
            pkg = package.ClientInitPackage(ip, port, client.token)
            await self.tunnel.tun_write(pkg)

            # Send the buffer read ahead of initialization through the tunnel
            if read_ahead:
                await self.tunnel.tun_data(client.token, read_ahead)

            # Serve data from the client; an empty read means it disconnected
            while True:
                data = await client.read(self.chunk_size)
                if not data:
                    break

                await self.tunnel.tun_data(client.token, data)

            # Tell the tunnel the client closed, unless the server has
            # stopped serving.
            if self.server and self.server.is_serving():
                pkg = package.ClientClosePackage(client.token)
                await self.tunnel.tun_write(pkg)
        finally:
            # Runs on every exit path so the client added above is not leaked.
            await self._disconnect_client(client.token)
コード例 #4
0
ファイル: p2p_server.py プロジェクト: HAOYUatHZ/pyquarkchain
 async def receive_handshake(
     self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
 ) -> None:
     """Handle an inbound connection: refuse blacklisted peers, otherwise
     run the handshake and close both streams if it fails.

     Blacklisted peers get their streams closed and no handshake attempt.
     Handshake failures are logged (via ``Logger.*_every_n``, presumably
     rate-limited to one in 100) rather than propagated. Task cancellation
     closes the streams and is then re-raised.
     """
     ip, port, *_ = writer.get_extra_info("peername")
     remote_address = Address(ip, port)
     if self.peer_pool.chk_dialin_blacklist(remote_address):
         Logger.info_every_n(
             "{} has been blacklisted, refusing connection".format(remote_address),
             100,
         )
         reader.feed_eof()
         writer.close()
         # The streams are closed now; do not go on to attempt a handshake.
         return
     expected_exceptions = (
         TimeoutError,
         PeerConnectionLost,
         HandshakeFailure,
         asyncio.IncompleteReadError,
         HandshakeDisconnectedFailure,
     )
     try:
         await self._receive_handshake(reader, writer)
     except expected_exceptions as e:
         self.logger.debug("Could not complete handshake: %s", e)
         Logger.error_every_n("Could not complete handshake: {}".format(e), 100)
         reader.feed_eof()
         writer.close()
     except OperationCancelled:
         self.logger.error("OperationCancelled")
         reader.feed_eof()
         writer.close()
     except asyncio.CancelledError:
         # Must precede `except Exception`: before Python 3.8 CancelledError
         # subclasses Exception and would otherwise be swallowed. Release the
         # streams, then let the cancellation propagate.
         reader.feed_eof()
         writer.close()
         raise
     except Exception:
         self.logger.exception("Unexpected error handling handshake")
         reader.feed_eof()
         writer.close()
コード例 #5
0
ファイル: server.py プロジェクト: hjlee9182/trinity
    async def receive_handshake(self, reader: asyncio.StreamReader,
                                writer: asyncio.StreamWriter) -> None:
        """Run the inbound handshake, closing both streams if it fails.

        Expected failures (``COMMON_RECEIVE_HANDSHAKE_EXCEPTIONS``) are logged
        at debug level, ``OperationCancelled`` is ignored, and any other
        ``Exception`` is logged with a traceback. ``asyncio.CancelledError``
        always propagates.
        """
        try:
            try:
                await self._receive_handshake(reader, writer)
            except BaseException:
                # BaseException, not Exception: since Python 3.8
                # asyncio.CancelledError no longer subclasses Exception, and the
                # streams must be released on cancellation too.
                if not reader.at_eof():
                    reader.feed_eof()
                writer.close()
                raise
        except COMMON_RECEIVE_HANDSHAKE_EXCEPTIONS as e:
            self.logger.debug("Could not complete handshake: %s", e)
        except asyncio.CancelledError:
            # This exception should just bubble.
            raise
        except OperationCancelled:
            pass
        except Exception:
            self.logger.exception("Unexpected error handling handshake")
コード例 #6
0
ファイル: server.py プロジェクト: marcgarreau/trinity
    async def receive_handshake(
            self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Perform the inbound handshake and tear the connection down on failure.

        Any exception from the handshake first closes both streams. Expected
        failures (``COMMON_RECEIVE_HANDSHAKE_EXCEPTIONS``) are then logged at
        debug level together with the peer address. ``asyncio.CancelledError``
        propagates, and any other ``Exception`` is logged with a traceback.
        """
        try:
            try:
                await self._receive_handshake(reader, writer)
            except BaseException:
                # Release both ends before deciding how to report the failure.
                reader_done = reader.at_eof()
                if not reader_done:
                    reader.feed_eof()
                writer.close()
                raise
        except COMMON_RECEIVE_HANDSHAKE_EXCEPTIONS as failure:
            self.logger.debug(
                "Could not complete handshake with %s: %s",
                writer.get_extra_info("peername"),
                failure,
            )
        except asyncio.CancelledError:
            # Cancellation is never swallowed.
            raise
        except Exception:
            self.logger.exception(
                "Unexpected error handling handshake with %s",
                writer.get_extra_info("peername"),
            )
コード例 #7
0
async def tcp_handler(reader: asyncio.StreamReader,
                      writer: asyncio.StreamWriter):
    """Read newline-terminated lines from a client and hand each to parse_line.

    The session ends on EOF, when no complete line arrives within 5 seconds,
    or on a connection reset. The writer is closed on every exit path, and
    task cancellation is propagated instead of being swallowed.
    """
    addr = writer.get_extra_info('peername')
    log.info("Client connected %r", addr)

    try:
        while not reader.at_eof():
            try:
                async with async_timeout.timeout(5):
                    line = await reader.readline()
                if line:
                    parse_line(line)
            except asyncio.TimeoutError:
                log.info('Client connection closed after timeout')
                break
            except ConnectionResetError:
                log.warning('Client connection reset')
                reader.feed_eof()
                break
    finally:
        # Release the socket whatever ended the session (EOF, timeout,
        # reset, error or cancellation).
        writer.close()
        log.info("Client disconnected %r", addr)
コード例 #8
0
class MQTTClientProtocol(FlowControlMixin, asyncio.Protocol):
    """asyncio protocol for the client end of an MQTT-framed tunnel.

    After connecting to ``config['address']:config['port']`` it sends a
    CONNECT packet carrying the encrypted password. It then relays local
    traffic as QoS 0 PUBLISH packets whose payloads are encrypted with
    ``self._encryptor``. Incoming PUBLISH packets are decrypted and routed by
    topic to objects registered with ``register_client_topic``. A PUBLISH
    with the retain flag set means "close this topic's connection". When the
    connection is lost, a reconnect is scheduled after 5 seconds.
    """

    def __init__(self, loop, config):
        # config keys used by this class: 'password', 'method', 'timeout',
        # 'address', 'port'.
        super().__init__(loop=loop)
        self._loop = loop
        self._config = config

        self._transport = None
        self._write_pending_data_topic = []     # tuple (data, topic)
        self._connected = False  # set by handle_connack() on return code 0

        self._encryptor = cryptor.Cryptor(self._config['password'], self._config['method'])

        self._peername = None

        self._reader_task = None
        self._data_task = None
        self._keepalive_task = None
        # Sent as keep_alive in CONNECT, used as the idle interval before a
        # PINGREQ (handle_write_timeout), and (+10 s) as the read timeout.
        self._keepalive_timeout = self._config['timeout']
        self._reader_ready = None
        # NOTE(review): asyncio.Event's loop= argument was removed in Python
        # 3.10; the same applies to StreamReader/Queue if they are asyncio's.
        self._reader_stopped = asyncio.Event(loop=self._loop)
        self._stream_reader = StreamReader(loop=self._loop)
        self._stream_writer = None
        self._reader = None

        self._topic_to_clients = {}  # topic -> registered local connection

        self._queue = Queue(loop=loop)  # outbound packets, drained by consume()

    async def create_connection(self):
        """Connect to the configured server, retrying every 5 seconds on OSError.

        The protocol factory returns ``self``, so on success the loop calls
        connection_made() on this object.
        """
        try:
            # TODO handle pending task
            transport, protocol = await self._loop.create_connection(lambda: self, self._config['address'], self._config['port'])
        except OSError as e:
            logging.error("{0} when connecting to mqtt server({1}:{2})".format(e, self._config['address'], self._config['port']))
            logging.error("Reconnection will be performed after 5s...")
            await asyncio.sleep(5)     # TODO:retry interval
            self._loop.create_task(self.create_connection())

    def connection_made(self, transport):
        """Wire the stream reader/writer to the new transport and schedule start()."""
        self._peername = transport.get_extra_info('peername')
        self._transport = transport

        self._stream_reader.set_transport(transport)
        self._reader = StreamReaderAdapter(self._stream_reader)
        self._stream_writer = StreamWriter(transport, self,
                                           self._stream_reader,
                                           self._loop)
        self._loop.create_task(self.start())

    def connection_lost(self, exc):
        """Forget topic routes, end the stream reader, then schedule a reconnect."""
        logging.info("Lost connection with mqtt server{0}".format(self._peername))
        super().connection_lost(exc)
        self._topic_to_clients = {}

        # Wake the reader loop: clean EOF, or re-raise the transport error.
        if self._stream_reader is not None:
            if exc is None:
                self._stream_reader.feed_eof()
            else:
                self._stream_reader.set_exception(exc)

        # NOTE(review): stop() is an @asyncio.coroutine generator; calling it
        # without awaiting/scheduling only creates a generator object, so this
        # line does not run stop(). Cleanup presumably happens when
        # _reader_loop() exits and yields from stop().
        self.stop()

        self.reestablish_connection()

    def reestablish_connection(self):
        """Create a fresh stream reader and encryptor, then reconnect in 5 s."""
        # NOTE(review): _reader_stopped is not cleared here, so after a
        # reconnect stop() takes its "reader already stopped" branch (closing
        # the transport) instead of cancelling the new reader task.
        self._stream_reader = StreamReader(loop=self._loop)
        self._encryptor = cryptor.Cryptor(self._config['password'], self._config['method'])
        self._loop.call_later(5, lambda: self._loop.create_task(self.create_connection()))

    def data_received(self, data):
        """Buffer incoming bytes for the reader loop."""
        self._stream_reader.feed_data(data)

    def eof_received(self):
        """Signal EOF to the reader loop."""
        self._stream_reader.feed_eof()

    @asyncio.coroutine
    def consume(self):
        """Send queued packets until a ``None`` packet arrives or the transport is cleared."""
        while self._transport is not None:
            packet = yield from self._queue.get()
            if packet is None:
                break

            if self._transport is None:
                break
            yield from self._send_packet(packet)

    @asyncio.coroutine
    def start(self):
        """Start the reader and sender tasks, arm keep-alive and send CONNECT."""
        self._reader_ready = asyncio.Event(loop=self._loop)
        self._reader_task = asyncio.Task(self._reader_loop(), loop=self._loop)
        # Wait until _reader_loop() has actually begun running.
        yield from self._reader_ready.wait()
        if self._keepalive_timeout:
            self._keepalive_task = self._loop.call_later(self._keepalive_timeout, self.handle_write_timeout)

        self._data_task = self._loop.create_task(self.consume())

        # send connect packet
        connect_vh = ConnectVariableHeader(keep_alive=self._keepalive_timeout)
        connect_vh.password_flag = True
        password = self._encryptor.encrypt(self._encryptor.password.encode('utf-8'))
        connect_payload = ConnectPayload(client_id=ConnectPayload.gen_client_id(), password=password)
        connect_packet = ConnectPacket(vh=connect_vh, payload=connect_payload)
        yield from self._do_write(connect_packet)

        logging.info("Creating connection to mqtt server.")

    @asyncio.coroutine
    def stop(self):
        """Stop keep-alive, the sender task and the reader.

        If the reader loop is still running it is cancelled. If it has already
        left its loop (``_reader_stopped`` is set) the transport is closed
        instead.
        """
        self._connected = False
        if self._keepalive_task:
            self._keepalive_task.cancel()
        # NOTE(review): _data_task and _reader_task start as None and are only
        # set by start(); calling stop() before then raises AttributeError.
        self._data_task.cancel()
        # NOTE(review): `logger` here vs `logging` everywhere else in this
        # class; confirm a module-level `logger` exists.
        logger.debug("waiting for tasks to be stopped")
        if not self._reader_task.done():

            if not self._reader_stopped.is_set():
                self._reader_task.cancel()  # this will cause the reader_loop handle CancelledError
                # yield from asyncio.wait(
                #     [self._reader_stopped.wait()], loop=self._loop)
            else:   # caused by reader_loop break statement
                if self._transport:
                    self._transport.close()
                    self._transport = None

    @asyncio.coroutine
    def _reader_loop(self):
        """Read and dispatch MQTT packets until the connection should end.

        CONNACK, PINGREQ, PINGRESP and DISCONNECT handlers run as separate
        tasks. PUBLISH is handled inline. The loop stops on a reserved packet
        type, EOF, no header within keepalive + 10 seconds, cancellation, or
        any other unexpected exception. Pending handler tasks are then
        cancelled, ``_reader_stopped`` is set and stop() is awaited.
        """
        running_tasks = collections.deque()
        while True:
            try:
                self._reader_ready.set()  # unblocks start()
                # Drop finished handler tasks from the front of the queue.
                while running_tasks and running_tasks[0].done():
                    running_tasks.popleft()
                if len(running_tasks) > 1:
                    logging.debug("{} Handler running tasks: {}".format(self._peername, len(running_tasks)))

                fixed_header = yield from asyncio.wait_for(
                    MQTTFixedHeader.from_stream(self._reader),
                    self._keepalive_timeout + 10, loop=self._loop)
                if fixed_header:
                    if fixed_header.packet_type == RESERVED_0 or fixed_header.packet_type == RESERVED_15:
                        logging.warning("{} Received reserved packet, which is forbidden: closing connection".format(self._peername))
                        break
                    else:
                        cls = packet_class(fixed_header)
                        packet = yield from cls.from_stream(self._reader, fixed_header=fixed_header)
                        task = None
                        if packet.fixed_header.packet_type == CONNACK:
                            task = ensure_future(self.handle_connack(packet), loop=self._loop)
                        elif packet.fixed_header.packet_type == PINGREQ:
                            task = ensure_future(self.handle_pingreq(packet), loop=self._loop)
                        elif packet.fixed_header.packet_type == PINGRESP:
                            task = ensure_future(self.handle_pingresp(packet), loop=self._loop)
                        elif packet.fixed_header.packet_type == PUBLISH:
                            # Handled synchronously, presumably so payloads are
                            # delivered in arrival order.
                            # task = ensure_future(self.handle_publish(packet), loop=self._loop)
                            self.handle_publish(packet)
                        # elif packet.fixed_header.packet_type == SUBSCRIBE:
                        #     task = ensure_future(self.handle_subscribe(packet), loop=self._loop)
                        # elif packet.fixed_header.packet_type == UNSUBSCRIBE:
                        #     task = ensure_future(self.handle_unsubscribe(packet), loop=self._loop)
                        # elif packet.fixed_header.packet_type == SUBACK:
                        #     task = ensure_future(self.handle_suback(packet), loop=self._loop)
                        # elif packet.fixed_header.packet_type == UNSUBACK:
                        #     task = ensure_future(self.handle_unsuback(packet), loop=self._loop)
                        elif packet.fixed_header.packet_type == DISCONNECT:
                            # NOTE(review): handle_disconnect is not defined in this
                            # class; if not inherited, the AttributeError is caught
                            # by the BaseException handler below and ends the loop.
                            task = ensure_future(self.handle_disconnect(packet), loop=self._loop)
                        else:
                            logging.warning("{} Unhandled packet type: {}".format(self._peername, packet.fixed_header.packet_type))
                        if task:
                            running_tasks.append(task)
                else:
                    logging.debug("{} No more data (EOF received), stopping reader coro".format(self._peername))
                    break
            except MQTTException:
                logging.debug("{} Message discarded".format(self._peername))
            except asyncio.CancelledError:
                # logger.debug("Task cancelled, reader loop ending")
                break
            except asyncio.TimeoutError:
                logging.debug("{} Input stream read timeout".format(self._peername))
                break
            except NoDataException:
                logging.debug("{} No data available".format(self._peername))
            except BaseException as e:
                logging.warning(
                    "{}:{} Unhandled exception in reader coro: {}".format(type(self).__name__, self._peername, e))
                break
        while running_tasks:
            running_tasks.popleft().cancel()
        self._reader_stopped.set()
        logging.debug("{} Reader coro stopped".format(self._peername))
        yield from self.stop()

    def write(self, data: bytes, topic):
        """Publish ``data`` on ``topic``, buffering it until CONNACK arrives.

        While not connected, (data, topic) pairs are buffered. When the buffer
        grows past 50 entries, all of them (including the new one) are
        discarded. Once connected, data is encrypted and queued as a QoS 0
        PUBLISH.
        """
        if not self._connected:
            self._write_pending_data_topic.append((data, topic))
            if len(self._write_pending_data_topic) > 50:
                self._write_pending_data_topic.clear()
        else:
            data = self._encryptor.encrypt(data)
            packet = PublishPacket.build(topic, data, None, dup_flag=0, qos=0, retain=0)
            ensure_future(self._do_write(packet), loop=self._loop)

    def write_eof(self, topic):
        """Queue an empty PUBLISH with retain=1, the end-of-stream marker for ``topic``."""
        packet = PublishPacket.build(topic, b'', None, dup_flag=0, qos=0, retain=1)
        ensure_future(self._do_write(packet), loop=self._loop)

    @asyncio.coroutine
    def _do_write(self, packet):
        """Put ``packet`` on the outbound queue drained by consume()."""
        yield from self._queue.put(packet)

    @asyncio.coroutine
    def _send_packet(self, packet):
        """Write ``packet`` to the stream and re-arm the keep-alive timer.

        A ConnectionResetError is swallowed and the timer is left untouched.
        """
        try:
            yield from packet.to_stream(self._stream_writer)
        except ConnectionResetError:
            return

        # NOTE(review): _keepalive_task is None when the configured timeout is
        # falsy (start() only arms it if truthy), which makes this raise.
        self._keepalive_task.cancel()
        self._keepalive_task = self._loop.call_later(self._keepalive_timeout, self.handle_write_timeout)

    def handle_write_timeout(self):
        """Send a PINGREQ after an idle interval and re-arm the timer."""
        packet = PingReqPacket()
        # TODO: check transport
        self._transport.write(packet.to_bytes())
        self._keepalive_task.cancel()
        self._keepalive_task = self._loop.call_later(self._keepalive_timeout, self.handle_write_timeout)

    def handle_read_timeout(self):
        """Schedule stop(); not referenced elsewhere in this class."""
        self._loop.create_task(self.stop())

    @asyncio.coroutine
    def handle_connack(self, connack: ConnackPacket):
        """Handle the server's CONNACK.

        Return code 0 marks the connection usable and flushes the buffered
        writes, with the keep-alive timer paused while doing so. Any other
        code schedules stop().
        """
        if connack.variable_header.return_code == 0:
            self._connected = True
            logging.info("Connection to mqtt server established!")

            if len(self._write_pending_data_topic) > 0:
                self._keepalive_task.cancel()
                for data, topic in self._write_pending_data_topic:
                    data = self._encryptor.encrypt(data)
                    packet = PublishPacket.build(topic, data, None, dup_flag=0, qos=0, retain=0)
                    yield from self._do_write(packet)
                self._write_pending_data_topic = []
                self._keepalive_task = self._loop.call_later(self._keepalive_timeout, self.handle_write_timeout)
        else:
            logging.info("Unable to create connection to mqtt server! Shuting down...")
            self._loop.create_task(self.stop())

    # @asyncio.coroutine
    def handle_publish(self, publish_packet: PublishPacket):
        """Deliver a PUBLISH from the server to the connection for its topic.

        Without the retain flag the payload is decrypted and written to the
        registered connection. With it, that connection is force-closed.
        Packets for unregistered topics are logged and dropped, although the
        payload is still decrypted.
        """
        data = bytes(publish_packet.data)

        server = self._topic_to_clients.get(publish_packet.topic_name, None)
        if server is None:
            logging.info("Received unregistered publish topic({0}) from mqtt server, packet will be ignored.".format(
                publish_packet.topic_name))
        if not publish_packet.retain_flag:    # retain=1 indicate we should close the client connection
            data = self._encryptor.decrypt(data)
            if server is not None:
                server.write(data)
        else:
            if server is not None:
                server.close(force=True)

    @asyncio.coroutine
    def handle_pingresp(self, pingresp: PingRespPacket):
        """Log a PINGRESP; nothing else to do."""
        logging.info("Received PingRespPacket from mqtt server.")

    @asyncio.coroutine
    def handle_pingreq(self, pingreq: PingReqPacket):
        """Answer a PINGREQ from the server with a PINGRESP."""
        logging.info("Received PingReqPacket from mqtt server, Replying PingResqPacket.")
        ping_resp = PingRespPacket()
        yield from self._do_write(ping_resp)

    def register_client_topic(self, topic, server):
        """Route incoming PUBLISH packets on ``topic`` to ``server``."""
        self._topic_to_clients[topic] = server

    def unregister_client_topic(self, topic):
        """Stop routing ``topic``; unknown topics are ignored."""
        self._topic_to_clients.pop(topic, None)
コード例 #9
0
class MQTTServerProtocol(FlowControlMixin, asyncio.Protocol):
    """asyncio protocol for the server end of an MQTT-framed tunnel.

    A client must first send CONNECT with an encrypted password matching
    ``config['password']``; unapproved PUBLISH packets stop the connection.
    Each PUBLISH topic then stands for one relayed stream. The first payload
    on a topic carries an address header (parsed by ``common.parse_header``)
    used to open a ``RelayRemoteProtocol`` connection. Later payloads are
    forwarded to that remote. Data to send back to the client is published
    on the same topic via ``write()``. A PUBLISH with the retain flag set
    means "close this topic".
    """

    def __init__(self, loop, config):
        # config keys used here: 'password', 'method', 'timeout'.
        super().__init__(loop=loop)
        self._loop = loop
        self._transport = None
        self._encryptor = cryptor.Cryptor(config['password'], config['method'])
        self._topic_to_remote = {}  # topic -> RelayRemoteProtocol

        self._peername = None

        self._reader_task = None
        self._data_task = None
        self._keepalive_task = None
        # Idle interval before a PINGREQ (handle_write_timeout); +10 s is the
        # read timeout in _reader_loop().
        self._keepalive_timeout = config['timeout']
        self._reader_ready = None
        # NOTE(review): asyncio.Event's loop= argument was removed in Python
        # 3.10; the same applies to StreamReader/Queue if they are asyncio's.
        self._reader_stopped = asyncio.Event(loop=self._loop)
        self._stream_reader = StreamReader(loop=self._loop)
        self._stream_writer = None
        self._reader = None
        self._approved = False  # set by handle_connect() on a correct password

        self._queue = Queue(loop=loop)  # outbound packets, drained by consume()

    def connection_made(self, transport):
        """Wire the stream reader/writer to the new transport and schedule start()."""
        self._peername = transport.get_extra_info('peername')
        self._transport = transport

        logging.info("Mqtt client connected from: {}.".format(self._peername))

        self._stream_reader.set_transport(transport)
        self._reader = StreamReaderAdapter(self._stream_reader)
        self._stream_writer = StreamWriter(transport, self,
                                           self._stream_reader, self._loop)
        self._loop.create_task(self.start())

    def connection_lost(self, exc):
        """End the stream reader so the reader loop can finish."""
        logging.info("Mqtt client connection{} lost.".format(self._peername))
        super().connection_lost(exc)

        # Wake the reader loop: clean EOF, or re-raise the transport error.
        if self._stream_reader is not None:
            if exc is None:
                self._stream_reader.feed_eof()
            else:
                self._stream_reader.set_exception(exc)

        # NOTE(review): stop() is an @asyncio.coroutine generator; calling it
        # without awaiting/scheduling only creates a generator object, so this
        # line does not run stop(). Cleanup presumably happens when
        # _reader_loop() exits and yields from stop().
        self.stop()

    def data_received(self, data):
        """Buffer incoming bytes for the reader loop."""
        self._stream_reader.feed_data(data)

    def eof_received(self):
        """Signal EOF to the reader loop."""
        self._stream_reader.feed_eof()

    @asyncio.coroutine
    def consume(self):
        """Send queued packets until a ``None`` packet arrives or the transport is cleared."""
        while self._transport is not None:
            packet = yield from self._queue.get()
            if packet is None:
                break

            if self._transport is None:
                break
            yield from self._send_packet(packet)

    @asyncio.coroutine
    def start(self):
        """Start the reader task, arm keep-alive and start the sender task."""
        self._reader_ready = asyncio.Event(loop=self._loop)
        self._reader_task = asyncio.Task(self._reader_loop(), loop=self._loop)
        # Wait until _reader_loop() has actually begun running.
        yield from self._reader_ready.wait()
        if self._keepalive_timeout:
            self._keepalive_task = self._loop.call_later(
                self._keepalive_timeout, self.handle_write_timeout)

        self._data_task = self._loop.create_task(self.consume())

    @asyncio.coroutine
    def stop(self):
        """Stop keep-alive, the sender task and the reader.

        If the reader loop is still running it is cancelled. If it has already
        left its loop (``_reader_stopped`` is set) the transport is closed and
        every remote connection is closed as well.
        """
        if self._keepalive_task:
            self._keepalive_task.cancel()
        # NOTE(review): _data_task and _reader_task start as None and are only
        # set by start(); calling stop() before then raises AttributeError.
        self._data_task.cancel()
        # NOTE(review): `logger` here vs `logging` everywhere else in this
        # class; confirm a module-level `logger` exists.
        logger.debug("waiting for tasks to be stopped")
        if not self._reader_task.done():
            if not self._reader_stopped.is_set():
                self._reader_task.cancel(
                )  # this will cause the reader_loop handle CancelledError
                # yield from asyncio.wait(
                #     [self._reader_stopped.wait()], loop=self._loop)
            else:  # caused by reader_loop break statement
                if self._transport:
                    self._transport.close()
                    self._transport = None

                    # NOTE(review): if remote.close() calls back into
                    # remove_topic() (which pops from this dict), this iteration
                    # would raise RuntimeError -- confirm.
                    for topic, remote in self._topic_to_remote.items():
                        remote.close()

    @asyncio.coroutine
    def _reader_loop(self):
        """Read and dispatch MQTT packets until the connection should end.

        CONNECT, PINGREQ, PINGRESP and DISCONNECT handlers run as separate
        tasks. PUBLISH is handled inline. The loop stops on a reserved packet
        type, EOF, no header within keepalive + 10 seconds, cancellation, or
        any other unexpected exception. Pending handler tasks are then
        cancelled, ``_reader_stopped`` is set and stop() is awaited.
        """
        running_tasks = collections.deque()
        while True:
            try:
                self._reader_ready.set()  # unblocks start()
                # Drop finished handler tasks from the front of the queue.
                while running_tasks and running_tasks[0].done():
                    running_tasks.popleft()
                if len(running_tasks) > 1:
                    logging.debug("{} Handler running tasks: {}".format(
                        self._peername, len(running_tasks)))

                fixed_header = yield from asyncio.wait_for(
                    MQTTFixedHeader.from_stream(self._reader),
                    self._keepalive_timeout + 10,
                    loop=self._loop)
                if fixed_header:
                    if fixed_header.packet_type == RESERVED_0 or fixed_header.packet_type == RESERVED_15:
                        logging.warning(
                            "{} Received reserved packet, which is forbidden: closing connection"
                            .format(self._peername))
                        break
                    else:
                        cls = packet_class(fixed_header)
                        packet = yield from cls.from_stream(
                            self._reader, fixed_header=fixed_header)
                        task = None
                        if packet.fixed_header.packet_type == CONNECT:
                            task = ensure_future(self.handle_connect(packet),
                                                 loop=self._loop)
                        elif packet.fixed_header.packet_type == PINGREQ:
                            task = ensure_future(self.handle_pingreq(packet),
                                                 loop=self._loop)
                        elif packet.fixed_header.packet_type == PINGRESP:
                            task = ensure_future(self.handle_pingresp(packet),
                                                 loop=self._loop)
                        elif packet.fixed_header.packet_type == PUBLISH:
                            # Handled synchronously, presumably so payloads are
                            # delivered in arrival order.
                            # task = ensure_future(self.handle_publish(packet), loop=self._loop)
                            self.handle_publish(packet)
                        # elif packet.fixed_header.packet_type == SUBSCRIBE:
                        #     task = ensure_future(self.handle_subscribe(packet), loop=self._loop)
                        # elif packet.fixed_header.packet_type == UNSUBSCRIBE:
                        #     task = ensure_future(self.handle_unsubscribe(packet), loop=self._loop)
                        # elif packet.fixed_header.packet_type == SUBACK:
                        #     task = ensure_future(self.handle_suback(packet), loop=self._loop)
                        # elif packet.fixed_header.packet_type == UNSUBACK:
                        #     task = ensure_future(self.handle_unsuback(packet), loop=self._loop)
                        elif packet.fixed_header.packet_type == DISCONNECT:
                            # NOTE(review): handle_disconnect is not defined in this
                            # class; if not inherited, the AttributeError is caught
                            # by the BaseException handler below and ends the loop.
                            task = ensure_future(
                                self.handle_disconnect(packet),
                                loop=self._loop)
                        else:
                            # TODO: handle unknow packet type
                            logging.warning(
                                "{} Unhandled packet type: {}".format(
                                    self._peername,
                                    packet.fixed_header.packet_type))
                        if task:
                            running_tasks.append(task)
                else:
                    logging.debug(
                        "{} No more data (EOF received), stopping reader coro".
                        format(self._peername))
                    break
            except MQTTException:
                logging.debug("{} Message discarded".format(self._peername))
            except asyncio.CancelledError:
                # logger.debug("Task cancelled, reader loop ending")
                break
            except asyncio.TimeoutError:
                logging.debug("{} Input stream read timeout".format(
                    self._peername))
                break
            except NoDataException:
                logging.debug("{} No data available".format(self._peername))
            except BaseException as e:
                logging.warning(
                    "{}:{} Unhandled exception in reader coro: {}".format(
                        type(self).__name__, self._peername, e))
                break
        while running_tasks:
            running_tasks.popleft().cancel()
        self._reader_stopped.set()
        logging.debug("{} Reader coro stopped".format(self._peername))
        yield from self.stop()

    # for remote read
    def write(self, data, client_topic):
        """Encrypt ``data`` and queue it as a QoS 0 PUBLISH on ``client_topic``.

        Presumably called by RelayRemoteProtocol with data read from the remote.
        """
        data = self._encryptor.encrypt(data)
        packet = PublishPacket.build(client_topic,
                                     data,
                                     None,
                                     dup_flag=0,
                                     qos=0,
                                     retain=0)
        ensure_future(self._do_write(packet), loop=self._loop)

    def _write_eof(self, client_topic):
        """Queue an empty PUBLISH with retain=1, the end-of-stream marker for ``client_topic``."""
        packet = PublishPacket.build(client_topic,
                                     b'',
                                     None,
                                     dup_flag=0,
                                     qos=0,
                                     retain=1)
        ensure_future(self._do_write(packet), loop=self._loop)

    @asyncio.coroutine
    def _do_write(self, packet):
        """Put ``packet`` on the outbound queue drained by consume()."""
        yield from self._queue.put(packet)

    @asyncio.coroutine
    def _send_packet(self, packet):
        """Write ``packet`` to the stream and re-arm the keep-alive timer."""
        # NOTE(review): unlike the client protocol, ConnectionResetError is not
        # caught here. _keepalive_task is also None when the configured timeout
        # is falsy, which makes the cancel() below raise.
        yield from packet.to_stream(self._stream_writer)
        self._keepalive_task.cancel()
        self._keepalive_task = self._loop.call_later(self._keepalive_timeout,
                                                     self.handle_write_timeout)

    def handle_write_timeout(self):
        """Send a PINGREQ after an idle interval and re-arm the timer."""
        packet = PingReqPacket()
        self._transport.write(packet.to_bytes())
        self._keepalive_task.cancel()
        self._keepalive_task = self._loop.call_later(self._keepalive_timeout,
                                                     self.handle_write_timeout)

    def handle_read_timeout(self):
        """Schedule stop(); not referenced elsewhere in this class."""
        self._loop.create_task(self.stop())

    @asyncio.coroutine
    def handle_connect(self, connect: ConnectPacket):
        """Check the CONNECT password and answer with a CONNACK.

        A mismatching password yields return code 4 (MQTT 3.1.1: "bad user
        name or password") and clears ``_approved``. stop() is scheduled
        after the CONNACK has been queued.
        """
        return_code = 0
        self._approved = True

        # NOTE(review): connect.password may be None if the client did not set
        # the password flag -- confirm decrypt() handles that.
        password = self._encryptor.decrypt(connect.password)
        password = password.decode('utf-8')
        if password != self._encryptor.password:
            return_code = 4
            self._approved = False
            logging.warning(
                "Invalid ConnectPacket password from mqtt client connection{}!"
                .format(self._peername))

        connack_vh = ConnackVariableHeader(return_code=return_code)
        connack = ConnackPacket(variable_header=connack_vh)
        yield from self._do_write(connack)

        if return_code != 0:
            self._loop.create_task(self.stop())

    # @asyncio.coroutine
    def handle_publish(self, publish_packet: PublishPacket):
        """Route a PUBLISH from an approved client.

        On an unknown topic the decrypted payload must start with an address
        header; a remote connection is opened for it, and any bytes after the
        header are forwarded. On a known topic the payload is forwarded to
        its remote. The retain flag force-closes the topic's remote. Empty or
        unparsable first payloads are answered with an end-of-stream PUBLISH.
        """

        if not self._approved:
            self._loop.create_task(self.stop())
            return

        data = bytes(publish_packet.data)
        remote = self._topic_to_remote.get(publish_packet.topic_name, None)
        if not publish_packet.retain_flag:
            data = self._encryptor.decrypt(data)
            if remote is None:  # we are in STAGE_ADDR
                if not data:
                    self._write_eof(publish_packet.topic_name)
                    return

                header_result = common.parse_header(data)
                if header_result is None:
                    logging.error(
                        "Can not parse header when handling mqtt client({}) connection{}."
                        .format(publish_packet.topic_name, self._peername))
                    self._write_eof(publish_packet.topic_name)
                    return

                addrtype, remote_addr, remote_port, header_length = header_result
                logging.info(
                    "Connecting to remote {}:{} from mqtt client({}) connection{}."
                    .format(common.to_str(remote_addr), remote_port,
                            publish_packet.topic_name, self._peername))

                remote = RelayRemoteProtocol(self._loop, self,
                                             publish_packet.topic_name)
                self._topic_to_remote[publish_packet.topic_name] = remote
                self._loop.create_task(
                    self.create_connection(remote, common.to_str(remote_addr),
                                           remote_port))

                # Bytes after the address header belong to the stream itself.
                if len(data) > header_length:
                    remote.write(data[header_length:])
            else:  # now in STAGE_STREAM
                remote.write(data)
        else:
            if remote is not None:
                remote.close(force=True)

    @asyncio.coroutine
    def handle_pingresp(self, pingresp: PingRespPacket):
        """Log a PINGRESP; nothing else to do."""
        logging.info("Received PingRespPacket from mqtt client.")

    @asyncio.coroutine
    def handle_pingreq(self, pingreq: PingReqPacket):
        """Answer a PINGREQ from the client with a PINGRESP."""
        logging.info(
            "Received PingRepPacket from mqtt client, replying PingRespPacket."
        )
        ping_resp = PingRespPacket()
        yield from self._do_write(ping_resp)

    async def create_connection(self, remote, host, port):
        """Open the TCP connection for ``remote``; on OSError log and drop its topic."""
        try:
            #TODO handle pending task
            transport, protocol = await self._loop.create_connection(
                lambda: remote, host, port)
        except OSError as e:
            logging.error(
                "{} when creating remote connection to {}:{} from mqtt connection{}."
                .format(e, host, port, self._peername))
            self.remove_topic(remote.client_topic)

    def remove_topic(self, topic):
        """Forget ``topic``; if our transport is still open, send the client an end-of-stream PUBLISH for it."""
        if self._transport is not None:
            self._write_eof(topic)
        self._topic_to_remote.pop(topic, None)
コード例 #10
0
ファイル: proxy.py プロジェクト: fkantelberg/socket-proxy
 async def close(self, reader: StreamReader, writer: StreamWriter) -> None:
     """Shut down a connection given its reader/writer pair.

     The reader is fed EOF first so pending reads can finish, then the
     writer is closed and this coroutine waits until the close completes.
     """
     for shutdown in (reader.feed_eof, writer.close):
         shutdown()
     await writer.wait_closed()
コード例 #11
0
ファイル: server.py プロジェクト: endlessgate/bitLabs
 def disconnect(self, reader: asyncio.StreamReader,
                writer: asyncio.StreamWriter) -> None:
     """Close a connection's reader/writer pair.

     EOF is fed to the reader only if it has not already reached EOF; the
     writer is closed unconditionally.
     """
     reader_finished = reader.at_eof()
     if not reader_finished:
         reader.feed_eof()
     writer.close()
コード例 #12
0
ファイル: server.py プロジェクト: Sovetnikov/livelock
class CommandProtocol(asyncio.Protocol):
    """Protocol that decodes RESP-style commands from a TCP connection.

    Received bytes are fed into a ``StreamReader``. A task started in
    :meth:`connection_made` decodes commands from that reader and passes each
    one to :meth:`on_command_received`.

    Subclasses must implement :meth:`on_command_received`,
    :meth:`reply_terminating` and :meth:`receive_commands_end`.

    NOTE(review): every callback reads ``self.kill_active``, but this class
    never sets it. It is presumably provided by a subclass or assigned
    externally; confirm it exists before the first callback runs.
    """

    def __init__(self,
                 tcp_keepalive_time=None,
                 tcp_keepalive_interval=None,
                 tcp_keepalive_probes=None,
                 tcp_user_timeout_seconds=None,
                 *args,
                 **kwargs):
        """Store TCP keep-alive and user-timeout settings for new connections.

        Each of the four settings is passed through ``int()``. Despite the
        ``None`` defaults, all four must be supplied, because ``int(None)``
        raises ``TypeError``. Remaining arguments go to the base class.

        :param tcp_keepalive_time: idle seconds before the first keep-alive probe.
        :param tcp_keepalive_interval: seconds between keep-alive probes.
        :param tcp_keepalive_probes: number of unanswered probes before the
            connection is dropped.
        :param tcp_user_timeout_seconds: TCP_USER_TIMEOUT in seconds; it is
            applied in milliseconds where the platform supports it.

        https://blog.cloudflare.com/when-tcp-sockets-refuse-to-die/

        http://coryklein.com/tcp/2015/11/25/custom-configuration-of-tcp-socket-keep-alive-timeouts.html
        Keep-Alive Process
        There are three configurable properties that determine how Keep-Alives work. On Linux they are:

        tcp_keepalive_time
        default 7200 seconds
        tcp_keepalive_probes
        default 9
        tcp_keepalive_intvl
        default 75 seconds
        The process works like this:

        1. Client opens TCP connection
        2. If the connection is silent for tcp_keepalive_time seconds, send a single empty ACK packet.
        3. Did the server respond with a corresponding ACK of its own?
            - No
            1. Wait tcp_keepalive_intvl seconds, then send another ACK
            2. Repeat until the number of ACK probes that have been sent equals tcp_keepalive_probes.
            3. If no response has been received at this point, send a RST and terminate the connection.
            - Yes: Return to step 2
        """
        self.transport = None
        self._reader = None
        # Strong reference to the command-reading task: the event loop keeps
        # only weak references to tasks, so an unreferenced task may be
        # garbage-collected mid-execution.
        self._receive_task = None
        self.tcp_keepalive_time = int(tcp_keepalive_time)
        self.tcp_keepalive_interval = int(tcp_keepalive_interval)
        self.tcp_keepalive_probes = int(tcp_keepalive_probes)
        self.tcp_user_timeout_seconds = int(tcp_user_timeout_seconds)
        super().__init__(*args, **kwargs)

    def data_received(self, data):
        """Feed received bytes to the stream reader (ignored while killed)."""
        if self.kill_active:
            return
        self._reader.feed_data(data)

    def connection_made(self, transport):
        """Set socket options, create the reader and start reading commands.

        Sets up keep-alive, TCP_NODELAY and TCP_USER_TIMEOUT (when available).
        """
        if self.kill_active:
            return

        sock = transport.get_extra_info('socket')
        set_tcp_keepalive(sock,
                          opts=dict(
                              tcp_keepalive=True,
                              tcp_keepalive_idle=self.tcp_keepalive_time,
                              tcp_keepalive_intvl=self.tcp_keepalive_interval,
                              tcp_keepalive_cnt=self.tcp_keepalive_probes,
                          ))
        # https://eklitzke.org/the-caveats-of-tcp-nodelay
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if hasattr(socket, 'TCP_USER_TIMEOUT'):
            # The socket option is expressed in milliseconds.
            logger.debug('Setting TCP_USER_TIMEOUT to %s',
                         self.tcp_user_timeout_seconds * 1000)
            sock.setsockopt(socket.SOL_TCP, socket.TCP_USER_TIMEOUT,
                            self.tcp_user_timeout_seconds * 1000)
        self.transport = transport
        self._reader = StreamReader()
        self._reader.set_transport(transport)

        loop = asyncio.get_event_loop()
        # Keep the task referenced for the lifetime of the connection.
        self._receive_task = loop.create_task(self.receive_commands())
        super().connection_made(transport)

    def connection_lost(self, exc):
        """Pass the end of the connection on to the reader.

        A clean close feeds EOF; otherwise the reader is given the exception.
        """
        if self.kill_active:
            return
        if self._reader is not None:
            if exc is None:
                self._reader.feed_eof()
            else:
                self._reader.set_exception(exc)
        super().connection_lost(exc)

    def eof_received(self):
        """Feed EOF to the reader and return the base class's ``eof_received`` result."""
        if self.kill_active:
            return
        self._reader.feed_eof()
        return super().eof_received()

    async def _read_int(self):
        """Read a CRLF-terminated line and parse it as an int."""
        if self.kill_active:
            return
        line = await self._reader.readuntil(b'\r\n')
        return int(line.decode().strip())

    async def _read_float(self):
        """Read a CRLF-terminated line and parse it as a float."""
        if self.kill_active:
            return
        line = await self._reader.readuntil(b'\r\n')
        return float(line.decode().strip())

    async def _read_bytes(self):
        """Read a bulk string; the ``$`` marker has already been consumed.

        Returns ``None`` for a negative length, ``b''`` for length zero, and
        otherwise the payload without its trailing CRLF. A trailing CRLF is
        read and validated in every case, including negative lengths.

        NOTE(review): standard RESP encodes a null bulk string as ``$-1\\r\\n``
        with no extra CRLF, yet this reader consumes two more bytes. Confirm
        that the peer really sends them.
        """
        if self.kill_active:
            return
        size = await self._read_int()
        line = await self._reader.readexactly(max(2, size + 2))
        if line[-1] != ord(b'\n'):
            raise Exception(r"line[-1] != ord(b'\n')")
        if line[-2] != ord(b'\r'):
            raise Exception(r"line[-2] != ord(b'\r')")
        if size < 0:
            return None
        if size == 0:
            return b''
        return line[:-2]

    async def _read_array(self):
        """Read an array and return its decoded items.

        The ``*`` marker has already been consumed. A negative count (for
        example ``*-1``) denotes a null array and returns ``None``.
        """
        if self.kill_active:
            return
        count = await self._read_int()
        # A negative count must not drive the loop: counting down from it
        # would never reach zero and would consume the stream indefinitely.
        if count is not None and count < 0:
            return None
        items = []
        # count is None only if the kill switch tripped inside _read_int.
        for _ in range(count or 0):
            marker = await self._reader.readexactly(1)
            items.append(await self._receive_resp(marker))
        return items

    async def _receive_resp(self, c):
        """Decode one value whose one-byte type marker *c* was already read."""
        if self.kill_active:
            return
        if c == b':':
            return await self._read_int()
        elif c == b'$':
            return await self._read_bytes()
        elif c == b'*':
            return await self._read_array()
        elif c == b',':
            return await self._read_float()
        else:
            raise Exception('Unknown RESP start char %s' % c)

    async def receive_commands(self):
        """Decode commands until the connection ends or the kill switch trips.

        A command is either a RESP value (first byte one of ``: * $ ,``),
        which is wrapped in a list unless it already is one, or an inline
        line split on single spaces with each part stripped and encoded.
        The command's items are passed positionally to
        :meth:`on_command_received`.
        """
        while True:
            if self.kill_active:
                self.reply_terminating()
                return
            try:
                c = await self._reader.readexactly(1)
                if c in b':*$,':
                    value = await self._receive_resp(c)
                    if not isinstance(value, list):
                        value = [
                            value,
                        ]
                else:
                    command = c + await self._reader.readuntil(b'\r\n')
                    value = [
                        x.strip().encode() for x in command.decode().split(' ')
                    ]
            except (ConnectionAbortedError, ConnectionResetError,
                    IncompleteReadError, TimeoutError) as e:
                # Connection is closed
                self.receive_commands_end(e)
                return
            await self.on_command_received(*value)

    async def on_command_received(self, command):
        """Handle one decoded command; must be overridden."""
        raise NotImplementedError()

    def reply_terminating(self):
        """Respond when the kill switch is active; must be overridden."""
        raise NotImplementedError()

    def receive_commands_end(self, exc):
        """Handle the end of the command stream caused by *exc*; must be overridden."""
        raise NotImplementedError()
コード例 #13
0
ファイル: fakeserial.py プロジェクト: gmega/callblocker
class ScriptedModem(SerialDeviceFactory, Protocol, AsyncioService):
    """
    ScriptedModem emulates the behavior of a serial modem by following a pre-programmed script. It expects
    commands to be issued in a certain order. It also supports timed actions (e.g., after 3 seconds, generate
    this command).

    It doubles as its own fake StreamWriter. :meth:`connect` returns ``(out_buffer, self)``. Bytes passed to
    :meth:`write` are fed into ``in_buffer``, except ``ATECHO`` commands in command mode, whose payload is fed
    into ``out_buffer``.
    """

    name = 'fake modem'

    def __init__(self, aio_loop_service: AsyncioEventLoop, command_mode=False, defer_script=False):
        """
        :param aio_loop_service: event-loop service handed to ``AsyncioService``.
        :param command_mode: if True, the script keeps running after its queue empties, and ``ATECHO``
            writes are echoed to the output buffer.
        :param defer_script: if True, scripted actions wait until :meth:`run_scripted_actions` is called.
        """
        AsyncioService.__init__(self, aio_loop_service)
        self.script = None
        # Actions registered before connect() are parked here until the queue exists.
        self._deferred_actions = []
        self.out_buffer = None
        self.in_buffer = None

        self.command_mode = command_mode

        self.defer_script = defer_script
        self._defer_event = None

    # --------------------- Scripting methods ---------------------------------

    def on_input(self, input: str) -> ReplyAction:
        """
        on_input allows scripting of the modem to react in response to a given
        expected input.

        :param input: the input the modem expects to receive at this point of the script.
        :return: the queued ReplyAction, so the caller can configure the reply (e.g. ``.reply('OK')``).
        """
        self._allow_states(*ServiceState.halted_states())
        reply = ReplyAction(input)
        self._add_action(reply)
        return reply

    def after(self, seconds: int):
        """Queue a TimedAction for *seconds* and return it for further configuration."""
        self._allow_states(*ServiceState.halted_states())
        timed = TimedAction(seconds)
        self._add_action(timed)
        return timed

    def load_script(self, lines: str, step: int = 0):
        """Queue one ``after(step).output(line)`` action per line of *lines*.

        :return: the action built for the last line, or None if *lines* has no lines.
        """
        self._allow_states(*ServiceState.halted_states())
        last = None
        for line in lines.splitlines(keepends=False):
            last = self.after(step).output(line)

        return last

    def run_scripted_actions(self):
        """
        Begins running the scripted actions (or enters command mode).

        This only has an effect when the modem was created with ``defer_script=True`` and connect() has
        already run. It is safe to call from another thread.
        :return:
        """
        if self._defer_event:
            self.aio_loop.call_soon_threadsafe(self._defer_event.set)

    def _add_action(self, action):
        """Queue *action* now if connected, otherwise park it until connect()."""
        if not self._add_action_now(action):
            self._deferred_actions.append(action)

    def _add_action_now(self, action):
        """Bind *action* to the loop and enqueue it; return False if there is no script queue yet."""
        if self.script is not None:
            action.set_loop(self.aio_loop)
            self.script.put_nowait(action)
            return True

        return False

    # ------------------- SerialDeviceFactory --------------------------------

    async def connect(self, aio_loop):
        """Create the buffers and script queue, start the service and return ``(out_buffer, self)``."""
        self._allow_states(*ServiceState.halted_states())

        # asyncio.Queue/Event dropped their ``loop`` argument in Python 3.10, where passing it raises
        # TypeError. Created inside this coroutine, they use the running loop.
        # NOTE(review): assumes connect() is awaited on ``aio_loop`` — confirm with callers.
        self.script = Queue()
        self.in_buffer = StreamReader(loop=aio_loop)
        self.out_buffer = StreamReader(loop=aio_loop)

        # Transfer deferred actions.
        for action in self._deferred_actions:
            self._add_action_now(action)

        self._deferred_actions = []

        # If the client wishes to delay startup of the scripted actions (e.g. to register an EventStream first),
        # we set this up here.
        if self.defer_script:
            self._defer_event = Event()

        # We're inside of an asyncio task, but need a synchronous start. Hack up
        # an asyncio version of it.
        started = Event()

        def async_start():
            self.sync_start()
            self.aio_loop.call_soon_threadsafe(started.set)

        Thread(target=async_start).start()

        await started.wait()

        return self.out_buffer, self

    # ------------------ Fake StreamWriter ------------------------------------

    def write(self, data: bytes):
        """Accept bytes from the client: handle them as a command in command mode, else feed in_buffer."""
        self._allow_states(ServiceState.READY)

        if self.command_mode and self._try_command(data):
            return

        self.in_buffer.feed_data(data)

    def _try_command(self, data: bytes):
        """Handle ``ATECHO <payload>`` by echoing the payload to out_buffer; return True if handled."""
        tokens = data. \
            decode(CX930xx_fake.encoding). \
            split(' ', maxsplit=1)

        command = tokens[0]
        payload = None if len(tokens) == 1 else tokens[1].encode(CX930xx_fake.encoding)

        # Echoes to output.
        if command == 'ATECHO' and payload:
            self.out_buffer.feed_data(payload)
            return True

        return False

    async def drain(self):
        """Pretend to flush written data."""
        # Simulates a slow write.
        await asyncio.sleep(0.1)

    @property
    def transport(self):
        """The modem is its own transport."""
        return self

    def close(self):
        """Closing the fake writer stops the service."""
        self.stop()

    # ------------------- Management methods ----------------------------------

    async def _event_loop(self):
        """Signal startup, optionally wait for run_scripted_actions(), then process queued actions."""
        self._signal_started()
        # Defers running the script if so requested by the user.
        if self._defer_event:
            await self._defer_event.wait()
        # If the queue is empty and we're not in command mode,
        # we're done.
        while (not self.script.empty()) or self.command_mode:
            event = await self.script.get()
            await event.process(self)

    def _graceful_cleanup(self):
        """Feed EOF to both buffers."""
        self.out_buffer.feed_eof()
        self.in_buffer.feed_eof()

    # ------------------- Convenience methods ---------------------------------

    @staticmethod
    def from_modem_type(modem_type: ModemType, aio_loop_service: AsyncioEventLoop) -> 'ScriptedModem':
        """Build a command-mode modem that replies 'OK' to each of the type's INIT commands."""
        modem = ScriptedModem(aio_loop_service=aio_loop_service, command_mode=True)
        for command in modem_type.commands[ModemType.INIT]:
            modem.on_input(command).reply('OK')

        return modem