예제 #1
0
    def test_consume_end_of_stream(self):
        """Consuming a stream that was closed without any data yields nothing."""
        closed_stream = ReceiveStream(1)
        closed_stream.close()

        # The message type is irrelevant here; a mock stands in for it.
        consumed = list(closed_stream.consume(Mock()))

        assert consumed == []
예제 #2
0
    def test_write_multiple_messages(self):
        """Ten writes of one complete single-byte message queue ten items."""
        stream = ReceiveStream(1)
        # Each write carries exactly one complete single-byte message.
        single_message = b"\x00\x00\x00\x00\x01\xff"

        writes_done = 0
        while writes_done < 10:
            stream.write(single_message)
            writes_done += 1

        # Every message was queued and nothing is left over in the buffer.
        assert stream.queue.qsize() == 10
        assert len(stream.buffer) == 0
예제 #3
0
 def test_write_more_bytes_than_one_message(self):
     """A complete message is queued; surplus bytes stay in the buffer."""
     stream = ReceiveStream(1)

     # One uncompressed single-byte message ...
     complete_message = b"\x00\x00\x00\x00\x01\xff"
     # ... followed by two extra bytes of \xff.
     surplus = b"\xff\xff"
     stream.write(complete_message + surplus)

     # The single-byte message is queued.
     queued = stream.queue.get()
     assert queued == (False, b"\xff")

     # The two trailing bytes remain in the buffer.
     leftover = stream.buffer.peek()
     assert leftover == surplus
예제 #4
0
    def test_consume_grpc_error(self):
        """A GrpcError placed on the queue is raised when the stream is consumed."""
        stream = ReceiveStream(1)
        stream.queue.put(GrpcError("boom", "details", "message"))

        with pytest.raises(GrpcError):
            # Both creating the consumer and advancing it stay inside the
            # `raises` block, so the error is accepted from either step.
            messages = stream.consume(Mock())
            next(messages)
예제 #5
0
    def test_consume_uncompressed_message(self):
        """An uncompressed queued payload is parsed into the message type."""
        stream = ReceiveStream(1)

        payload = b"x"
        message_type = Mock()
        # Calling the mock returns its shared return value, i.e. the same
        # object any later call to `message_type()` produces.
        expected_message = message_type()

        stream.queue.put((False, payload))
        # Close the stream so that `consume` terminates.
        stream.close()

        consumed = list(stream.consume(message_type))

        assert consumed == [expected_message]
        parse_calls = expected_message.ParseFromString.call_args_list
        assert parse_calls == [call(payload)]
예제 #6
0
    def test_write_to_closed_stream(self):
        """Writing to a closed stream leaves its buffer empty."""
        closed_stream = ReceiveStream(1)

        # Precondition: a fresh stream starts with an empty buffer.
        assert closed_stream.buffer.empty()

        closed_stream.close()
        late_bytes = b"\x00\x00\x00"
        closed_stream.write(late_bytes)

        # The late write must not have been buffered.
        assert closed_stream.buffer.empty()
예제 #7
0
    def send_request(self, request_headers):
        """ Called by the client to invoke a GRPC method.

        Allocate the next stream id from `self.counter`, create a `SendStream`
        for the request payload and a `ReceiveStream` for the eventual
        response, and register both under that id. The request headers are
        set on the `SendStream`, and the stream id is appended to
        `self.pending_requests`; invocations are queued and sent on the next
        iteration of the event loop.

        Returns a `(SendStream, ReceiveStream)` tuple so the client can
        provide the request payload and iterate over the response.
        """
        sid = next(self.counter)

        outgoing = SendStream(sid)
        incoming = ReceiveStream(sid)

        # Register both halves of the exchange under the same stream id.
        self.receive_streams[sid] = incoming
        self.send_streams[sid] = outgoing

        outgoing.headers.set(*request_headers)

        # Queue the invocation rather than sending it here.
        self.pending_requests.append(sid)

        return outgoing, incoming
예제 #8
0
    def request_received(self, event):
        """ Receive a GRPC request and pass it to the GrpcServer to fire any
        appropriate entrypoint.

        Create a `ReceiveStream` for the request payload and a `SendStream`
        for the eventual response, register both under the event's stream id,
        and populate the request headers from the event. The response headers
        and a default `grpc-status: 0` trailer are prepared before the request
        is handed to `handle_request`.

        If a `GrpcError` is raised while preparing the response or handling
        the request, the response trailers are replaced with the error's
        headers and the stream is ended.
        """
        super().request_received(event)

        sid = event.stream_id

        incoming = ReceiveStream(sid)
        outgoing = SendStream(sid)
        self.receive_streams[sid] = incoming
        self.send_streams[sid] = outgoing

        incoming.headers.set(*event.headers, from_wire=True)

        # Choose the response compression from the request's encoding headers.
        accept_encoding = incoming.headers.get("grpc-accept-encoding")
        request_encoding = incoming.headers.get("grpc-encoding")
        compression = select_algorithm(accept_encoding, request_encoding)

        try:
            success_headers = (
                (":status", "200"),
                ("content-type", "application/grpc+proto"),
                ("grpc-accept-encoding", ",".join(SUPPORTED_ENCODINGS)),
                # TODO support server changing compression later
                ("grpc-encoding", compression),
            )
            outgoing.headers.set(*success_headers)
            outgoing.trailers.set(("grpc-status", "0"))
            self.handle_request(incoming, outgoing)

        except GrpcError as error:
            # Report the failure through the trailers, then end the stream.
            error_trailers = [(":status", "200")]
            error_trailers.extend(error.as_headers())
            outgoing.trailers.set(*error_trailers)
            self.end_stream(sid)
예제 #9
0
    def test_write_less_bytes_than_one_message(self):
        """An incomplete message is kept in the buffer and not queued."""
        stream = ReceiveStream(1)

        # NOTE(review): presumably a 5-byte header declaring more payload than
        # the three bytes that follow -- confirm against the framing code.
        incomplete = b"\x00\x00\x00\x01\x00\xff\xff\xff"
        stream.write(incomplete)

        # Nothing is queued; all written bytes are still buffered.
        assert stream.queue.empty()
        assert stream.buffer.peek() == incomplete
예제 #10
0
    def test_write_less_bytes_than_header(self):
        """Three bytes alone are kept in the buffer and not queued."""
        stream = ReceiveStream(1)

        # NOTE(review): presumably shorter than the message header -- confirm
        # the header size against the framing code.
        partial_header = b"\x00\x00\x00"
        stream.write(partial_header)

        # Nothing is queued; the written bytes are still buffered.
        assert stream.queue.empty()
        assert stream.buffer.peek() == partial_header