def test_invoke_unary_rpc_keep_open(self) -> None:
    """A keep_open unary RPC still processes responses after it terminates."""
    method = self._service.SomeUnary.method
    first_payload = method.response_type(payload='-_-')
    second_payload = method.response_type(payload='0_o')

    self._enqueue_response(1, method, Status.ABORTED, first_payload)

    replies: list = []

    def record(_, reply):
        # Passed as both callbacks below, so response payloads and the
        # final status are collected in a single ordered list.
        replies.append(reply)

    request = method.request_type(magic_number=6)
    self._service.SomeUnary.invoke(request, record, record, keep_open=True)

    self.assertEqual([first_payload, Status.ABORTED], replies)

    # Send another packet and make sure it is processed even though the RPC
    # terminated.
    late_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.RESPONSE,
        channel_id=1,
        service_id=method.service.id,
        method_id=method.id,
        status=Status.OK.value,
        payload=second_payload.SerializeToString(),
    )
    self._client.process_packet(late_packet.SerializeToString())

    self.assertEqual(
        [first_payload, Status.ABORTED, second_payload, Status.OK], replies)
def _enqueue_response(self,
                      channel_id: int,
                      method=None,
                      status: Status = Status.OK,
                      response=b'',
                      *,
                      ids: Tuple[int, int] = None,
                      process_status=Status.OK):
    """Queue a serialized RESPONSE packet on ``self._next_packets``.

    The target RPC is given either by ``method`` or by ``ids``, never both.

    Args:
      channel_id: Channel id written into the packet.
      method: Object whose ``service.id`` and ``id`` supply the service and
        method ids.
      status: Status whose ``.value`` is written into the packet.
      response: Packet payload. ``bytes`` are used verbatim; anything else is
        serialized with ``SerializeToString()``.
      ids: Explicit ``(service_id, method_id)`` pair, used when ``method`` is
        not given.
      process_status: Stored next to the serialized packet in the queued
        tuple.
    """
    # NOTE(review): ``ids`` defaults to None, so its annotation should be
    # Optional[Tuple[int, int]] (PEP 484 disallows implicit Optional).
    if not method:
        assert ids is not None and method is None
        service_id, method_id = ids
    else:
        assert ids is None
        service_id, method_id = method.service.id, method.id

    payload = (response if isinstance(response, bytes) else
               response.SerializeToString())

    packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.RESPONSE,
        channel_id=channel_id,
        service_id=service_id,
        method_id=method_id,
        status=status.value,
        payload=payload,
    )
    self._next_packets.append((packet.SerializeToString(), process_status))
def encode_cancel(rpc: tuple) -> bytes:
    """Serialize a CANCEL_SERVER_STREAM packet addressed to ``rpc``.

    Args:
      rpc: Value that ``_ids`` unpacks into (channel, service, method) ids.

    Returns:
      The serialized ``RpcPacket``.
    """
    channel_id, service_id, method_id = _ids(rpc)
    cancel_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CANCEL_SERVER_STREAM,
        channel_id=channel_id,
        service_id=service_id,
        method_id=method_id,
    )
    return cancel_packet.SerializeToString()
def encode_response(rpc: tuple, response: message.Message) -> bytes:
    """Serialize a RESPONSE packet for ``rpc`` carrying ``response``.

    Args:
      rpc: Value that ``_ids`` unpacks into (channel, service, method) ids.
      response: Message serialized into the packet payload.

    Returns:
      The serialized ``RpcPacket``. No status field is set explicitly.
    """
    channel_id, service_id, method_id = _ids(rpc)
    response_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.RESPONSE,
        channel_id=channel_id,
        service_id=service_id,
        method_id=method_id,
        payload=response.SerializeToString(),
    )
    return response_packet.SerializeToString()
def _enqueue_error(self,
                   channel_id: int,
                   method,
                   status: Status,
                   process_status=Status.OK):
    """Queue a serialized SERVER_ERROR packet on ``self._next_packets``.

    Args:
      channel_id: Channel id written into the packet.
      method: Object whose ``service.id`` and ``id`` supply the service and
        method ids.
      status: Status whose ``.value`` is written into the packet.
      process_status: Stored next to the serialized packet in the queued
        tuple.
    """
    error_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.SERVER_ERROR,
        channel_id=channel_id,
        service_id=method.service.id,
        method_id=method.id,
        status=status.value,
    )
    queued = (error_packet.SerializeToString(), process_status)
    self._next_packets.append(queued)
def encode_client_error(packet, status: Status) -> bytes:
    """Serialize a CLIENT_ERROR packet that reuses ``packet``'s routing ids.

    Args:
      packet: Packet whose channel, service, and method ids are copied.
      status: Status whose ``.value`` is written into the error packet.

    Returns:
      The serialized ``RpcPacket``.
    """
    error_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_ERROR,
        channel_id=packet.channel_id,
        service_id=packet.service_id,
        method_id=packet.method_id,
        status=status.value,
    )
    return error_packet.SerializeToString()
def decode(data: bytes):
    """Parse serialized bytes into a freshly constructed ``RpcPacket``."""
    return packet_pb2.RpcPacket.FromString(data)