Example #1
async def tcp_send(data,
                   sender: asyncio.StreamWriter,
                   delimiter: bytes = b'\n',
                   timeout=None):
    """
    Get data, convert to str before encoding for simplicity.
    DO NOT PASS BYTES TO DATA! Or will end up receiving b'b'1231''.
    Didn't put type checking for optimization factor.
    """

    data_byte = str(data).encode()
    try:
        data_length = len(data_byte)
        sender.write(str(data_length).encode() + delimiter + data_byte)

    except TypeError:
        msg = "tcp_send: expects"
        if not isinstance(delimiter, bytes):
            print(msg,
                  f"<bytes> for delimiter, got {type(delimiter)} instead.")

        if not isinstance(timeout, Number) and timeout is not None:
            print(msg,
                  f"<numbers> for delimiter, got {type(timeout)} instead.")

        raise

    await asyncio.wait_for(sender.drain(), timeout=timeout)
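
A minimal usage sketch for tcp_send() above; this is not from the source, the host, port, and payload are placeholders, and asyncio (plus numbers.Number) is assumed to be imported alongside the helper.

async def demo_send() -> None:
    # Hypothetical endpoint; any peer reading '<length><delimiter><payload>' frames will do.
    reader, writer = await asyncio.open_connection('127.0.0.1', 8888)
    try:
        await tcp_send(1231, writer, timeout=5)  # writes b'4\n1231'
    finally:
        writer.close()

asyncio.run(demo_send())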
Example #2
async def _connect_streams(reader: asyncio.StreamReader,
                           writer: asyncio.StreamWriter,
                           queue: "asyncio.Queue[int]",
                           token: CancelToken) -> None:
    try:
        while not token.triggered:
            if reader.at_eof():
                break

            try:
                size = queue.get_nowait()
            except asyncio.QueueEmpty:
                await asyncio.sleep(0)
                continue
            data = await token.cancellable_wait(reader.readexactly(size))
            writer.write(data)
            queue.task_done()
            await token.cancellable_wait(writer.drain())
    except OperationCancelled:
        pass
    finally:
        writer.write_eof()

    if reader.at_eof():
        reader.feed_eof()
Example #3
async def connection_loop(execute_rpc: Callable[[Any], Any],
                          reader: asyncio.StreamReader,
                          writer: asyncio.StreamWriter,
                          logger: logging.Logger,
                          cancel_token: CancelToken) -> None:
    # TODO: we should look into using an io.StringIO here for more efficient
    # writing to the end of the string.
    raw_request = ''
    while True:
        request_bytes = b''
        try:
            request_bytes = await cancel_token.cancellable_wait(reader.readuntil(b'}'))
        except asyncio.LimitOverrunError as e:
            logger.info("Client request was too long. Erasing buffer and restarting...")
            request_bytes = await cancel_token.cancellable_wait(reader.read(e.consumed))
            await cancel_token.cancellable_wait(write_error(
                writer,
                f"reached limit: {e.consumed} bytes, starting with '{request_bytes[:20]!r}'",
            ))
            continue

        raw_request += request_bytes.decode()

        bad_prefix, raw_request = strip_non_json_prefix(raw_request)
        if bad_prefix:
            logger.info("Client started request with non json data: %r", bad_prefix)
            await cancel_token.cancellable_wait(
                write_error(writer, 'Cannot parse json: ' + bad_prefix),
            )

        try:
            request = json.loads(raw_request)
        except json.JSONDecodeError:
            # invalid json request, keep reading data until a valid json is formed
            logger.debug("Invalid JSON, waiting for rest of message: %r", raw_request)
            continue

        # reset the buffer for the next message
        raw_request = ''

        if not request:
            logger.debug("Client sent empty request")
            await cancel_token.cancellable_wait(
                write_error(writer, 'Invalid Request: empty'),
            )
            continue

        try:
            result = await execute_rpc(request)
        except Exception as e:
            logger.exception("Unrecognized exception while executing RPC")
            await cancel_token.cancellable_wait(
                write_error(writer, "unknown failure: " + str(e)),
            )
        else:
            writer.write(result.encode())

        await cancel_token.cancellable_wait(writer.drain())
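
The TODO at the top of this example suggests io.StringIO for cheaper appends than repeated string concatenation; a minimal sketch of that idea (an assumption, not code from the source):

import io

buf = io.StringIO()
buf.write('{"jsonrpc": "2.0", ')    # append fragments as they arrive
buf.write('"method": "ping"}')
raw_request = buf.getvalue()        # materialize once, when a parse is attempted
buf = io.StringIO()                 # reset the buffer for the next message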
Example #4
    def nbd_response(self,
                     writer: StreamWriter,
                     handle: int,
                     error: int = 0,
                     data: Optional[bytes] = None) -> Generator[Any, None, None]:
        writer.write(struct.pack('>LLQ', self.NBD_REPLY_MAGIC, error, handle))
        if data:
            writer.write(data)
        yield from writer.drain()
Example #5
def handle_stream(reader: StreamReader, writer: StreamWriter):
    addr = writer.get_extra_info('peername')
    print("Connect from %r" % (addr, ))

    labels_list = load_labels('labels/labels.txt')
    load_graph('model/my_frozen_graph_okyonsei.pb')

    model = load_model()

    count = 0
    detected_count = 0
    last_data_dict = dict()

    try:
        while True:
            count += 1
            data = yield from reader.readexactly(8)
            user_id, body_len = struct.unpack("!II", data)
            data = yield from reader.readexactly(body_len)

            last_data = last_data_dict.get(user_id)
            last_data_dict[user_id] = data

            if last_data is None:
                continue

            print(count)
            wav_data, np_data = convert_pcm_to_wav(last_data + data)

            with open(f'data/{count:06}.wav', 'wb') as fout:
                fout.write(wav_data)

            detected, prediction = run_graph(wav_data, labels_list, 3)

            if detected:
                if detected_count:
                    detected_count = 0
                    continue

                try:
                    start_p = max(np_data.argmax() - 300, 0)
                    name = find_who(model, np_data[start_p:])
                    # Korean greeting: "Hello, {name} [accuracy: NN %]"
                    msg = f'안녕하세요. {name} [정확도: {int(prediction * 100)} %]'
                    msg_encoded = str.encode(msg)
                    header = struct.pack("!II", user_id, len(msg_encoded))
                    writer.write(header + msg_encoded)
                    yield from writer.drain()
                    detected_count += 1
                except Exception:
                    pass
            else:
                detected_count = 0

    except Exception as e:
        print(e)
        writer.close()
Example #6
def connected_cb(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """
    A callback for connected clients.
    """
    client = writer.get_extra_info("peername")
    sclient = ':'.join(str(_) for _ in client)
    logger.info("Recieved connection from {}:{}".format(*client))
    # Read a subscription message.
    try:
        sub = yield from reader.read(65536)
    except ConnectionResetError:
        logger.info("Client {} closed connection".format(sclient))
        return
    if not sub:
        logger.error("Client {} terminated connection abnormally".format(sclient))
        return
    try:
        sub_data = msgpack.unpackb(sub)
    except (msgpack.UnpackException, ValueError) as e:
        logger.error("Recieved unknown subscription message from {}:{}".format(*client))
        yield from writer.drain()
        writer.close()
        return
    # Get the data from the subscription message.
    if b'queue' not in sub_data:
        logger.error("Received null queue from {}".format(sclient))
        yield from writer.drain()
        writer.close()
        return
    queue_to_sub = sub_data[b"queue"]
    action = sub_data.get(b"action", 0)
    queue_created = False
    if queue_to_sub not in queues:
        queues[queue_to_sub] = [0, asyncio.Queue()]
        logger.debug("Created queue {}".format(queue_to_sub))
        queue_created = True
    logger.debug("Client {} subscribed to queue {} in mode {} ({})".format(sclient, queue_to_sub,
                                                                           action, "push" if not action else "pull"))
    if action == 0:
        loop.create_task(QueueWaiter(reader, writer, queue_to_sub))
    else:
        loop.create_task(QueueSender(reader, writer, queue_to_sub))
    msgpack.pack({"created": queue_created}, writer)
Example #7
async def _write_packet(writer: asyncio.StreamWriter,
                        packet: Union[CommandPacket, ResponsePacket]) -> None:
    try:
        data = packet.to_bytes()
        writer.write(data)
        await asyncio.wait_for(writer.drain(), _WRITE_TIMEOUT)
    except asyncio.TimeoutError as exception:
        raise ConnectionFailed() from exception
    except ConnectionError as exception:
        raise ConnectionFailed() from exception
    except OSError as exception:
        raise ConnectionFailed() from exception
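
A hypothetical caller for _write_packet() above, relying only on names already shown (ConnectionFailed); the surrounding connection handling is an assumption, not part of the source:

async def send_or_close(writer: asyncio.StreamWriter, packet) -> bool:
    # Illustrative only: report failure instead of propagating ConnectionFailed.
    try:
        await _write_packet(writer, packet)
    except ConnectionFailed:
        writer.close()
        return False
    return True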
Example #8
    def handle_connection(self, client_reader: asyncio.StreamReader,
                          client_writer: asyncio.StreamWriter,
                          user_dict: dict):
        """
        Example of how to handle a new connection.
        """
        client_writer.write(
            'Greetings.  Lines will be echoed back to you.  Type EOF to exit.\n'
            .encode(ExampleLineEchoNonblockingServer.ENCODING))
        yield from client_writer.drain()  # Non-blocking

        from_client = None
        while from_client != 'EOF':  # Magic kill word from client
            # Discussion: I think it might be a bug (or "feature") in the Python 3.4.3 I'm developing on
            # that there is not a proper error thrown with this readline or the following write.
            # There's a note here: https://github.com/aaugustin/websockets/issues/23
            # In any event, wrapping the read in an asyncio.async(..) solves the problem.
            # Without async(..) when a connection is dropped externally, the code goes into an infinite loop
            # reading an empty string b'' and then writing with the should-be-exception getting swallowed
            # somewhere inside Python's codebase.
            # from_client = yield from client_reader.readline()  # This should be OK, but it's not, strangely
            # from_client = yield from asyncio.shield(client_reader.readline())  # Also works
            # from_client = yield from asyncio.ensure_future(client_reader.readline())  # Use this instead

            from_client = yield from client_reader.readline()

            if from_client == b'':  # Connection probably closed
                from_client = "EOF"
            else:
                from_client = from_client.decode('utf-8').strip()
                print("Recvd: [{}]".format(from_client))

                response = "{}\n".format(from_client).encode(
                    ExampleLineEchoNonblockingServer.ENCODING)
                client_writer.write(response)
                yield from client_writer.drain()  # Non-blocking
Example #9
    async def send_handshake(
        self,
        writer: asyncio.StreamWriter,
        protocol: TCPProtocolV1,
    ):
        # Reply with our side of the handshake.
        writer.write(protocol.generate_handshake())
        # writer.drain() waits for the socket to flush its send buffer.
        try:
            await asyncio.wait_for(writer.drain(), protocol.handshake_timeout)
        except asyncio.TimeoutError:
            self.logger.error(
                'sending handshake to remote %s timeout',
                writer.get_extra_info('peername'),
            )
            return False
        return True
Example #10
    async def __pump_traffic(up_reader: StreamReader, up_writer: StreamWriter,
                             down_reader: StreamReader,
                             down_writer: StreamWriter, *first_chunks):
        async def upstream_channel():
            log.debug('Starting upstream channel.')
            while True:
                log.debug('Waiting for data from client.')
                data = await wait_for(down_reader.read(n=CHUNK_SIZE),
                                      IO_TIMEOUT)
                log.debug('Got %r[...] from client.', data[:16])
                if not data:
                    log.debug('No more data from client.')
                    up_writer.close()
                    break
                log.debug('Sending data to upstream.')
                up_writer.write(data)
                await wait_for(up_writer.drain(), IO_TIMEOUT)

        async def downstream_channel():
            log.debug('Starting downstream channel.')
            while True:
                log.debug('Waiting for data from upstream.')
                data = await wait_for(up_reader.read(n=CHUNK_SIZE), IO_TIMEOUT)
                log.debug('Got %r[...] from upstream.', data[:16])
                if not data:
                    log.debug('No more data from upstream.')
                    break
                log.debug('Sending data to client.')
                down_writer.write(data)
                await wait_for(down_writer.drain(), IO_TIMEOUT)

        log.debug('Sending first chunks.')
        for chunk in first_chunks:
            up_writer.write(chunk)
        await wait_for(up_writer.drain(), IO_TIMEOUT)

        log.debug('Start pumping traffic.')
        await gather(upstream_channel(), downstream_channel())

        log.debug('Client connection has been processed.')
Example #11
async def handler(reader: StreamReader, writer: StreamWriter) -> None:
    try:
        req_info = await RequestInfo.init(reader=reader)
        if req_info.method == 'POST':
            if req_info.is_multipart:
                pass
                # TODO: implement multipart upload, see https://tools.ietf.org/html/rfc2046
            else:
                await _handler(reader=reader, writer=writer, req_info=req_info)
        else:
            writer.write(
                b'HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n400 Bad Request\r\n\r\n'
            )

    except Exception as e:
        logger.error(f'{e} --- {e.args}')
        writer.write(
            b'HTTP/1.1 500 Internal Server Error\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\nInternal Server Error\r\n\r\n'
        )
    finally:
        await asyncio.wait_for(writer.drain(), 2)
        writer.close()
Example #12
    def response(self, writer: asyncio.StreamWriter, code: Code, response):
        payload = msgpack.dumps(response)
        length = len(payload)

        writer.write(code.pack() + Types.ulong.pack(length) + payload)
        yield from writer.drain()
Example #13
    async def _receive_handshake(self, reader: asyncio.StreamReader,
                                 writer: asyncio.StreamWriter) -> None:
        msg = await self.wait(reader.read(ENCRYPTED_AUTH_MSG_LEN),
                              timeout=REPLY_TIMEOUT)

        ip, socket, *_ = writer.get_extra_info("peername")
        remote_address = Address(ip, socket)
        self.logger.debug("Receiving handshake from %s", remote_address)
        got_eip8 = False
        try:
            ephem_pubkey, initiator_nonce, initiator_pubkey = decode_authentication(
                msg, self.privkey)
        except DecryptionError:
            # Try to decode as EIP8
            got_eip8 = True
            msg_size = big_endian_to_int(msg[:2])
            remaining_bytes = msg_size - ENCRYPTED_AUTH_MSG_LEN + 2
            msg += await self.wait(reader.read(remaining_bytes),
                                   timeout=REPLY_TIMEOUT)
            try:
                ephem_pubkey, initiator_nonce, initiator_pubkey = decode_authentication(
                    msg, self.privkey)
            except DecryptionError as e:
                self.logger.debug("Failed to decrypt handshake: %s", e)
                return

        initiator_remote = Node(initiator_pubkey, remote_address)
        responder = HandshakeResponder(initiator_remote, self.privkey,
                                       got_eip8, self.cancel_token)

        responder_nonce = secrets.token_bytes(HASH_LEN)
        auth_ack_msg = responder.create_auth_ack_message(responder_nonce)
        auth_ack_ciphertext = responder.encrypt_auth_ack_message(auth_ack_msg)

        # Use the `writer` to send the reply to the remote
        writer.write(auth_ack_ciphertext)
        await self.wait(writer.drain())

        # Call `HandshakeResponder.derive_shared_secrets()` and use return values to create `Peer`
        aes_secret, mac_secret, egress_mac, ingress_mac = responder.derive_secrets(
            initiator_nonce=initiator_nonce,
            responder_nonce=responder_nonce,
            remote_ephemeral_pubkey=ephem_pubkey,
            auth_init_ciphertext=msg,
            auth_ack_ciphertext=auth_ack_ciphertext)

        # Create and register peer in peer_pool
        peer = self.peer_class(
            remote=initiator_remote,
            privkey=self.privkey,
            reader=reader,
            writer=writer,
            aes_secret=aes_secret,
            mac_secret=mac_secret,
            egress_mac=egress_mac,
            ingress_mac=ingress_mac,
            headerdb=self.headerdb,
            network_id=self.network_id,
            inbound=True,
        )

        if self.peer_pool.is_full:
            peer.disconnect(DisconnectReason.too_many_peers)
        else:
            # We use self.wait() here as a workaround for
            # https://github.com/ethereum/py-evm/issues/670.
            await self.wait(self.do_handshake(peer))
Example #14
    def to_stream(self, writer: asyncio.StreamWriter):
        writer.write(self.to_bytes())
        yield from writer.drain()
Example #15
    def to_stream(self, writer: asyncio.StreamWriter):
        writer.write(self.to_bytes())
        yield from writer.drain()
        self.protocol_ts = datetime.now()
Example #16
    async def _receive_handshake(self, reader: asyncio.StreamReader,
                                 writer: asyncio.StreamWriter) -> None:
        msg = await self.wait(reader.read(ENCRYPTED_AUTH_MSG_LEN),
                              timeout=REPLY_TIMEOUT)

        ip, socket, *_ = writer.get_extra_info("peername")
        remote_address = Address(ip, socket)
        self.logger.debug("Receiving handshake from %s", remote_address)
        got_eip8 = False
        try:
            ephem_pubkey, initiator_nonce, initiator_pubkey = decode_authentication(
                msg, self.privkey)
        except DecryptionError:
            # Try to decode as EIP8
            got_eip8 = True
            msg_size = big_endian_to_int(msg[:2])
            remaining_bytes = msg_size - ENCRYPTED_AUTH_MSG_LEN + 2
            msg += await self.wait(reader.read(remaining_bytes),
                                   timeout=REPLY_TIMEOUT)
            try:
                ephem_pubkey, initiator_nonce, initiator_pubkey = decode_authentication(
                    msg, self.privkey)
            except DecryptionError as e:
                self.logger.debug("Failed to decrypt handshake: %s", e)
                return

        initiator_remote = Node(initiator_pubkey, remote_address)
        responder = HandshakeResponder(initiator_remote, self.privkey,
                                       got_eip8, self.cancel_token)

        responder_nonce = numpy.random.bytes(HASH_LEN)
        auth_ack_msg = responder.create_auth_ack_message(responder_nonce)
        auth_ack_ciphertext = responder.encrypt_auth_ack_message(auth_ack_msg)

        # Use the `writer` to send the reply to the remote
        writer.write(auth_ack_ciphertext)
        await self.wait(writer.drain())

        # Call `HandshakeResponder.derive_shared_secrets()` and use return values to create `Peer`
        aes_secret, mac_secret, egress_mac, ingress_mac = responder.derive_secrets(
            initiator_nonce=initiator_nonce,
            responder_nonce=responder_nonce,
            remote_ephemeral_pubkey=ephem_pubkey,
            auth_init_ciphertext=msg,
            auth_ack_ciphertext=auth_ack_ciphertext,
        )
        connection = PeerConnection(
            reader=reader,
            writer=writer,
            aes_secret=aes_secret,
            mac_secret=mac_secret,
            egress_mac=egress_mac,
            ingress_mac=ingress_mac,
        )

        # Create and register peer in peer_pool
        peer = self.peer_pool.get_peer_factory().create_peer(
            remote=initiator_remote, connection=connection, inbound=True)

        if self.peer_pool.is_full:
            await peer.disconnect(DisconnectReason.too_many_peers)
            return
        elif not self.peer_pool.is_valid_connection_candidate(peer.remote):
            await peer.disconnect(DisconnectReason.useless_peer)
            return

        total_peers = len(self.peer_pool)
        inbound_peer_count = len([
            peer for peer in self.peer_pool.connected_nodes.values()
            if peer.inbound
        ])
        if total_peers > 1 and inbound_peer_count / total_peers > DIAL_IN_OUT_RATIO:
            # make sure to have at least 1/4 outbound connections
            await peer.disconnect(DisconnectReason.too_many_peers)
        else:
            # We use self.wait() here as a workaround for
            # https://github.com/ethereum/py-evm/issues/670.
            await self.wait(self.do_handshake(peer))
Example #17
def fdms_session(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    online = None
    ''':type: (FdmsHeader, FdmsTransaction)'''
    add_on = None
    ''':type: (FdmsHeader, FdmsTransaction)'''
    offline = list()

    writer.write(bytes((ENQ,)))
    yield from writer.drain()

    while True:

        # Get Request
        attempt = 0
        while True:
            try:
                if attempt > 4:
                    return

                request = yield from asyncio.wait_for(read_fdms_packet(reader), timeout=15.0)
                if len(request) == 0:
                    return

                control_byte = request[0]
                if control_byte == STX:
                    lrs = functools.reduce(lambda x, y: x ^ int(y), request[2:-1], int(request[1]))
                    if lrs != request[-1]:
                        raise ValueError('LRS sum')

                    pos, header = parse_header(request)
                    txn = header.create_txn()
                    txn.parse(request[pos:-2])
                    if header.txn_type == FdmsTransactionType.Online.value:
                        if online is None:
                            online = (header, txn)
                        else:
                            add_on = (header, txn)
                    else:
                        offline.append((header, txn))

                    if header.protocol_type == '2':
                        break

                    # Respond with ACK
                    attempt = 0
                    writer.write(bytes((ACK,)))

                elif control_byte == EOT:
                    break

            # Close session
            except asyncio.TimeoutError:
                return

            # Respond with NAK
            except Exception as e:
                logging.getLogger(LOG_NAME).debug('Request error: %s', str(e))
                attempt += 1
                writer.write(bytes((NAK,)))

            yield from writer.drain()

        if online is None:
            return

        # Process Transactions & Send Response
        for txn in offline:
            rs = process_txn(txn)
        offline.clear()

        if add_on is not None:
            process_add_on_txn(online, add_on)
        add_on = None

        rs = process_txn(online)

        # Send Response
        rs_bytes = rs.response()

        if rs.action_code == FdmsActionCode.HostSpecificPoll or rs.action_code == FdmsActionCode.RevisionInquiry:
            writer.write(rs_bytes)
            yield from writer.drain()
        else:
            attempt = 0
            while True:
                if attempt >= 4:
                    return

                writer.write(rs_bytes)
                yield from writer.drain()

                control_byte = 0
                try:
                    while True:
                        rs_head = yield from asyncio.wait_for(reader.read(1), timeout=4.0)
                        if len(rs_head) == 0:
                            return
                        control_byte = rs_head[0] & 0x7f
                        if control_byte == ACK:
                            break
                        elif control_byte == NAK:
                            break
                # Close session
                except asyncio.TimeoutError as e:
                    return

                if control_byte == ACK:
                    break
                else:
                    attempt += 1

            if online[0].wcc in {'B', 'C'}:
                # Send ENQ
                writer.write(bytes((ENQ,)))
                yield from writer.drain()
                continue
            else:
                break

    writer.write(bytes((EOT,)))
    yield from writer.drain()
    if writer.can_write_eof():
        writer.write_eof()
Example #18
    def __handle_connection(self, client_reader: asyncio.StreamReader,
                            client_writer: asyncio.StreamWriter):
        """
        Handle new TCP connection.

        This function only determines the initial Raptor command and then hands off control
        to an appropriate handler.  The only known Raptor commands at this time are DIRECTORY,
        which gives Raptor a list of known test cases, or the name of a test to try loading.

        :param asyncio.StreamReader client_reader: Associated input stream
        :param asyncio.StreamWriter client_writer: Associated output stream
        """
        # http://www.drdobbs.com/open-source/the-new-asyncio-in-python-34-servers-pro/240168408

        # This timeout handling feels awkward.  I want to wrap EVERYTHING in a timeout.
        # To do this, I need to catch a TimeoutError NOT wrapping the asyncio.wait_for() function
        # directly but rather from the calling context.  In other words, try/except goes around
        # the wrap_timeout_outer() call, but it's _within_ the wrap_timeout_outer() function
        # where we actually have the asyncio.wait_for() function.  The wait_for() function in
        # turn calls the wrap_timeout_inner() function, which actually does the work.

        user_dict = {}

        @asyncio.coroutine
        def wrap_timeout_outer():
            """ One line wrapper """
            if self.connection_timeout is None:
                yield from self.handle_connection(
                    client_reader, client_writer,
                    user_dict)  # Subclasses override this
            else:

                @asyncio.coroutine
                def wrap_timeout_inner():
                    yield from self.handle_connection(client_reader,
                                                      client_writer, user_dict)

                yield from asyncio.wait_for(wrap_timeout_inner(),
                                            self.connection_timeout)

        try:
            # MAIN PROCESSING ENTERS HERE
            yield from wrap_timeout_outer()
        except Exception as e:
            if self.verbosity >= 10:
                print(
                    "Caught exception. Will pass to exception callback handler {}."
                    .format(self.handle_exception),
                    e,
                    file=sys.stderr)
            yield from self.handle_exception(
                e, client_reader, client_writer,
                user_dict)  # Subclasses override this

        finally:
            if self.verbosity >= 10:
                print("Closing connection")
            try:
                yield from client_writer.drain()
            except Exception as e:
                pass
            try:
                client_writer.close()
            except Exception as e:
                print("error closing", file=sys.stderr)
                pass
Example #19
    def handler(self, reader: StreamReader,
                writer: StreamWriter) -> Generator[Any, None, None]:
        data: Optional[bytes]
        try:
            host, port = writer.get_extra_info("peername")
            version: Optional[Version] = None
            cow_version: Optional[Version] = None
            self.log.info("Incoming connection from %s:%s." % (host, port))

            # Initial handshake
            writer.write(
                struct.pack(">QQH", self.INIT_PASSWD, self.CLISERV_MAGIC,
                            self.NBD_HANDSHAKE_FLAGS))
            yield from writer.drain()

            data = yield from reader.readexactly(4)
            try:
                client_flags = struct.unpack(">L", data)[0]
            except struct.error:
                raise IOError("Handshake failed, disconnecting.")

            # We support both fixed and unfixed new-style negotiation.
            # The specification actually allows a client supporting "fixed" to not set this bit in its reply ("SHOULD").
            fixed = (client_flags & self.NBD_FLAG_FIXED_NEWSTYLE) != 0
            if not fixed:
                self.log.warning(
                    "Client did not signal fixed new-style handshake.")

            client_flags ^= self.NBD_FLAG_FIXED_NEWSTYLE
            if client_flags > 0:
                raise IOError(
                    "Handshake failed, unknown client flags %s, disconnecting."
                    % (client_flags))

            # Negotiation phase
            while True:
                header = yield from reader.readexactly(16)
                try:
                    (magic, opt, length) = struct.unpack(">QLL", header)
                except struct.error:
                    raise IOError(
                        "Negotiation failed: Invalid request, disconnecting.")

                if magic != self.CLISERV_MAGIC:
                    raise IOError("Negotiation failed: Bad magic number: %s." %
                                  magic)

                if length:
                    data = yield from reader.readexactly(length)
                    if len(data) != length:
                        raise IOError(
                            "Negotiation failed: %s bytes expected." % length)
                else:
                    data = None

                self.log.debug("[%s:%s]: opt=%s, length=%s, data=%s" %
                               (host, port, opt, length, data))

                if opt == self.NBD_OPT_EXPORTNAME:
                    if not data:
                        raise IOError(
                            "Negotiation failed: No export name was provided.")

                    version_uid = VersionUid(data.decode("ascii"))
                    if version_uid not in [
                            v.uid for v in self.store.get_versions()
                    ]:
                        if not fixed:
                            raise IOError(
                                "Negotiation failed: Unknown export name.")

                        writer.write(
                            struct.pack(">QLLL", self.NBD_OPT_REPLY_MAGIC, opt,
                                        self.NBD_REP_ERR_UNSUP, 0))
                        yield from writer.drain()
                        continue

                    self.log.info("[%s:%s] Negotiated export: %s." %
                                  (host, port, version_uid.v_string))

                    # We have negotiated a version and it will be used until the client disconnects
                    version = self.store.get_versions(
                        version_uid=version_uid)[0]
                    self.store.open(version)

                    self.log.info("[%s:%s] Version %s has been opened." %
                                  (host, port, version.uid))

                    export_flags = self.NBD_EXPORT_FLAGS
                    if self.read_only:
                        export_flags |= self.NBD_FLAG_READ_ONLY
                        self.log.info("[%s:%s] Export is read only." %
                                      (host, port))
                    else:
                        self.log.info("[%s:%s] Export is read/write." %
                                      (host, port))

                    # In case size is not a multiple of 4096 we round it up to a
                    # multiple of the maximum supported block size of 4096.
                    size = math.ceil(version.size / 4096) * 4096
                    writer.write(struct.pack('>QH', size, export_flags))
                    writer.write(b"\x00" * 124)
                    yield from writer.drain()

                    # Transition to transmission phase
                    break

                elif opt == self.NBD_OPT_LIST:
                    # Don't use version as a loop variable so we don't conflict with the outer scope usage
                    for list_version in self.store.get_versions():
                        list_version_encoded = list_version.uid.v_string.encode(
                            "ascii")
                        writer.write(
                            struct.pack(">QLLL", self.NBD_OPT_REPLY_MAGIC, opt,
                                        self.NBD_REP_SERVER,
                                        len(list_version_encoded) + 4))
                        writer.write(
                            struct.pack(">L", len(list_version_encoded)))
                        writer.write(list_version_encoded)
                        yield from writer.drain()

                    writer.write(
                        struct.pack(">QLLL", self.NBD_OPT_REPLY_MAGIC, opt,
                                    self.NBD_REP_ACK, 0))
                    yield from writer.drain()

                elif opt == self.NBD_OPT_ABORT:
                    writer.write(
                        struct.pack(">QLLL", self.NBD_OPT_REPLY_MAGIC, opt,
                                    self.NBD_REP_ACK, 0))
                    yield from writer.drain()

                    raise _NbdServerAbortedNegotiationError()
                else:
                    # We don't support any other option
                    if not fixed:
                        raise IOError("Unsupported option: %s." % (opt))

                    writer.write(
                        struct.pack(">QLLL", self.NBD_OPT_REPLY_MAGIC, opt,
                                    self.NBD_REP_ERR_UNSUP, 0))
                    yield from writer.drain()

            # Transmission phase
            while True:
                header = yield from reader.readexactly(28)
                try:
                    (magic, cmd, handle, offset,
                     length) = struct.unpack(">LLQQL", header)
                except struct.error:
                    raise IOError("Invalid request, disconnecting.")

                if magic != self.NBD_REQUEST_MAGIC:
                    raise IOError("Bad magic number, disconnecting.")

                cmd_flags = cmd & self.NBD_CMD_MASK_FLAGS
                cmd = cmd & self.NBD_CMD_MASK_COMMAND

                self.log.debug(
                    "[%s:%s]: cmd=%s, cmd_flags=%s, handle=%s, offset=%s, len=%s"
                    % (host, port, cmd, cmd_flags, handle, offset, length))

                # We don't support any command flags
                if cmd_flags != 0:
                    yield from self.nbd_response(writer,
                                                 handle,
                                                 error=self.EINVAL)
                    continue

                if cmd == self.NBD_CMD_DISC:
                    self.log.info("[%s:%s] disconnecting" % (host, port))
                    break

                elif cmd == self.NBD_CMD_WRITE:
                    data = yield from reader.readexactly(length)
                    if len(data) != length:
                        raise IOError("%s bytes expected, disconnecting." %
                                      length)

                    if self.read_only:
                        yield from self.nbd_response(writer,
                                                     handle,
                                                     error=self.EPERM)
                        continue

                    if not cow_version:
                        cow_version = self.store.get_cow_version(version)
                    try:
                        self.store.write(cow_version, offset, data)
                    except Exception as exception:
                        self.log.error(
                            "[%s:%s] NBD_CMD_WRITE: %s\n%s." %
                            (host, port, exception, traceback.format_exc()))
                        yield from self.nbd_response(writer,
                                                     handle,
                                                     error=self.EIO)
                        continue

                    yield from self.nbd_response(writer, handle)

                elif cmd == self.NBD_CMD_READ:
                    try:
                        data = self.store.read(version, cow_version, offset,
                                               length)
                    except Exception as exception:
                        self.log.error(
                            "[%s:%s] NBD_CMD_READ: %s\n%s." %
                            (host, port, exception, traceback.format_exc()))
                        yield from self.nbd_response(writer,
                                                     handle,
                                                     error=self.EIO)
                        continue

                    yield from self.nbd_response(writer, handle, data=data)

                elif cmd == self.NBD_CMD_FLUSH:
                    # Return success right away when we're read only or when we haven't written anything yet.
                    if self.read_only or not cow_version:
                        yield from self.nbd_response(writer, handle)
                        continue

                    try:
                        self.store.flush(cow_version)
                    except Exception as exception:
                        self.log.error(
                            "[%s:%s] NBD_CMD_FLUSH: %s\n%s." %
                            (host, port, exception, traceback.format_exc()))
                        yield from self.nbd_response(writer,
                                                     handle,
                                                     error=self.EIO)
                        continue

                    yield from self.nbd_response(writer, handle)

                else:
                    self.log.warning("[%s:%s] Unknown cmd %s, ignoring." %
                                     (host, port, cmd))
                    yield from self.nbd_response(writer,
                                                 handle,
                                                 error=self.EINVAL)
                    continue

        except _NbdServerAbortedNegotiationError:
            self.log.info("[%s:%s] Client aborted negotiation." % (host, port))

        except (asyncio.IncompleteReadError, IOError) as exception:
            self.log.error("[%s:%s] %s" % (host, port, exception))

        finally:
            if cow_version:
                self.store.fixate(cow_version)
            if version:
                self.store.close(version)
            writer.close()
Example #20
async def _write_tcp_packet(packet: Packet, writer: aio.StreamWriter) -> None:
    data = packet.to_bytes()
    writer.write(struct.pack('>H', len(data)) + data)
    await writer.drain()
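
A possible read-side counterpart for the 2-byte big-endian length prefix written above; the function name is an assumption and decoding the body back into a Packet is left to the caller:

async def _read_tcp_packet(reader: aio.StreamReader) -> bytes:
    header = await reader.readexactly(2)      # the '>H' length prefix
    (length,) = struct.unpack('>H', header)
    return await reader.readexactly(length)   # raw packet body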
Example #21
    async def receive_connection(cls, reader: asyncio.StreamReader,
                                 writer: asyncio.StreamWriter,
                                 private_key: datatypes.PrivateKey,
                                 token: CancelToken) -> TransportAPI:
        try:
            msg = await token.cancellable_wait(
                reader.readexactly(ENCRYPTED_AUTH_MSG_LEN),
                timeout=REPLY_TIMEOUT,
            )
        except asyncio.IncompleteReadError as err:
            raise HandshakeFailure from err

        try:
            ephem_pubkey, initiator_nonce, initiator_pubkey = decode_authentication(
                msg,
                private_key,
            )
        except DecryptionError as non_eip8_err:
            # Try to decode as EIP8
            msg_size = big_endian_to_int(msg[:2])
            remaining_bytes = msg_size - ENCRYPTED_AUTH_MSG_LEN + 2

            try:
                msg += await token.cancellable_wait(
                    reader.readexactly(remaining_bytes),
                    timeout=REPLY_TIMEOUT,
                )
            except asyncio.IncompleteReadError as err:
                raise HandshakeFailure from err

            try:
                ephem_pubkey, initiator_nonce, initiator_pubkey = decode_authentication(
                    msg,
                    private_key,
                )
            except DecryptionError as eip8_err:
                raise HandshakeFailure(
                    f"Failed to decrypt both EIP8 handshake: {eip8_err}  and "
                    f"non-EIP8 handshake: {non_eip8_err}")
            else:
                got_eip8 = True
        else:
            got_eip8 = False

        peername = writer.get_extra_info("peername")
        if peername is None:
            socket = writer.get_extra_info("socket")
            sockname = writer.get_extra_info("sockname")
            raise HandshakeFailure(
                "Received incoming connection with no remote information:"
                f"socket={repr(socket)}  sockname={sockname}")

        ip, socket, *_ = peername
        remote_address = Address(ip, socket)

        cls.logger.debug("Receiving handshake from %s", remote_address)

        initiator_remote = Node(initiator_pubkey, remote_address)

        responder = HandshakeResponder(initiator_remote, private_key, got_eip8,
                                       token)

        responder_nonce = secrets.token_bytes(HASH_LEN)

        auth_ack_msg = responder.create_auth_ack_message(responder_nonce)
        auth_ack_ciphertext = responder.encrypt_auth_ack_message(auth_ack_msg)

        # Use the `writer` to send the reply to the remote
        writer.write(auth_ack_ciphertext)
        await token.cancellable_wait(writer.drain())

        # Call `HandshakeResponder.derive_shared_secrets()` and use return values to create `Peer`
        aes_secret, mac_secret, egress_mac, ingress_mac = responder.derive_secrets(
            initiator_nonce=initiator_nonce,
            responder_nonce=responder_nonce,
            remote_ephemeral_pubkey=ephem_pubkey,
            auth_init_ciphertext=msg,
            auth_ack_ciphertext=auth_ack_ciphertext)

        transport = cls(
            remote=initiator_remote,
            private_key=private_key,
            reader=reader,
            writer=writer,
            aes_secret=aes_secret,
            mac_secret=mac_secret,
            egress_mac=egress_mac,
            ingress_mac=ingress_mac,
        )
        return transport
Example #22
def write_with_header(writer: asyncio.StreamWriter, packet: pb.Packet):
    tpl = struct.Struct('>I')
    data = packet.encode_to_bytes()
    writer.write(tpl.pack(len(data)))
    writer.write(data)
    yield from writer.drain()