async def test_late_cancel_unary_stream(self):
    """Test cancellation after received all messages."""
    # Ask the server for _NUM_STREAM_RESPONSES responses, each carrying a
    # payload of _RESPONSE_PAYLOAD_SIZE bytes.
    request = messages_pb2.StreamingOutputCallRequest()
    request.response_parameters.extend([
        messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
        for _ in range(_NUM_STREAM_RESPONSES)
    ])

    call = self._stub.StreamingOutputCall(request)

    # Pull exactly the expected number of messages with explicit reads.
    for _ in range(_NUM_STREAM_RESPONSES):
        message = await call.read()
        self.assertIs(type(message), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(message.payload.body))

    # Every message has arrived, but the final status may already be here or
    # may still be in flight. That race is inherent, so the only expectation
    # is that cancelling now does not crash and the reported status is either
    # OK (status won the race) or CANCELLED (cancel won the race).
    call.cancel()
    acceptable_codes = (grpc.StatusCode.OK, grpc.StatusCode.CANCELLED)
    self.assertIn(await call.code(), acceptable_codes)
async def test_error_in_async_generator(self):
    """Cancelling a stream-unary call interrupts its async request generator.

    The request generator yields slowly; a background task cancels the call
    part-way through. The awaiting caller must see CancelledError, and the
    generator itself must observe a CancelledError while it is producing.
    """
    # NOTE(review): this request is sent to StreamingInputCall, which
    # presumably ignores response_parameters, so the interval may have no
    # effect on the server side. Confirm before relying on it.
    request = messages_pb2.StreamingOutputCallRequest()
    request.response_parameters.append(
        messages_pb2.ResponseParameters(
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_RESPONSE_INTERVAL_US,
        ))

    # Set by the generator once its assertRaises(CancelledError) has passed.
    generator_saw_cancellation = asyncio.Event()

    async def produce_requests():
        with self.assertRaises(asyncio.CancelledError):
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request
                await asyncio.sleep(test_constants.SHORT_TIMEOUT)
        generator_saw_cancellation.set()

    call = self._stub.StreamingInputCall(produce_requests())

    # Cancel after two generator sleep intervals, presumably while the
    # generator is still producing requests.
    async def cancel_after_delay():
        await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
        call.cancel()

    canceller = self.loop.create_task(cancel_after_delay())

    # The caller awaiting the call is expected to see the cancellation.
    with self.assertRaises(asyncio.CancelledError):
        await call

    # The generator must also have been interrupted by the cancellation.
    await generator_saw_cancellation.wait()

    # Surface any failure raised inside the cancelling task.
    await canceller
def _custom_metadata(stub):
    """Verify that custom initial and trailing metadata are echoed back.

    Sends a str-valued initial-metadata entry and a bytes-valued
    trailing-metadata entry, first on a UnaryCall and then on a
    FullDuplexCall. Checks that each response carries the same values back.

    Args:
        stub: TestService stub used to issue the calls.

    Raises:
        ValueError: if either metadata entry is missing from the response or
            differs from the value that was sent.
    """
    initial_metadata_value = "test_initial_metadata_value"
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
                (_TRAILING_METADATA_KEY, trailing_metadata_value))

    def _validate_metadata(response):
        # .get() turns a missing key into a descriptive ValueError
        # ("... got None") instead of an unexplained KeyError.
        initial_metadata = dict(response.initial_metadata())
        received_initial = initial_metadata.get(_INITIAL_METADATA_KEY)
        if received_initial != initial_metadata_value:
            raise ValueError('expected initial metadata %s, got %s' %
                             (initial_metadata_value, received_initial))
        trailing_metadata = dict(response.trailing_metadata())
        received_trailing = trailing_metadata.get(_TRAILING_METADATA_KEY)
        if received_trailing != trailing_metadata_value:
            raise ValueError('expected trailing metadata %s, got %s' %
                             (trailing_metadata_value, received_trailing))

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'))
    response_future = stub.UnaryCall.future(request, metadata=metadata)
    _validate_metadata(response_future)

    # Testing with FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),))
        pipe.add(request)  # Sends the request
        next(response_iterator)  # Wait for the first response
    # Leaving the with block closes the pipe, ending the request stream.
    # Presumably this lets the server finish the call and send trailing
    # metadata.
    _validate_metadata(response_iterator)
async def test_unary_stream(self):
    """A unary-stream call yields every requested response, then status OK.

    The channel is opened with ``async with`` so it is closed even when an
    assertion fails part-way through. Previously a failure skipped the
    trailing ``channel.close()`` and leaked the channel.
    """
    async with aio.insecure_channel(self._server_target) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)

        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

        # Invokes the actual RPC
        call = stub.StreamingOutputCall(request)

        # Validates the responses
        response_cnt = 0
        async for response in call:
            response_cnt += 1
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                             len(response.payload.body))

        self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
def _cancel_after_first_response(stub):
    """Cancel a full-duplex call right after its first response.

    Args:
        stub: TestService stub used to issue the FullDuplexCall.

    Raises:
        grpc.RpcError: re-raised if the call fails with any status other
            than CANCELLED.
        ValueError: if the iterator yields another response instead of
            failing after cancellation.
    """
    request_response_sizes = (31415, 9, 2653, 58979)
    request_payload_sizes = (27182, 8, 1828, 45904)
    # Only the leading entry of each size table is used by this case.
    first_response_size = request_response_sizes[0]
    first_payload_size = request_payload_sizes[0]

    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        pipe.add(
            messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(messages_pb2.ResponseParameters(
                    size=first_response_size),),
                payload=messages_pb2.Payload(
                    body=b'\x00' * first_payload_size)))

        # Block until the first response arrives. Its contents are checked
        # in the Ping Pong test, so they are deliberately ignored here.
        next(response_iterator)

        response_iterator.cancel()

        # After cancellation, the next pull must fail with CANCELLED.
        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.CANCELLED:
                raise
        else:
            raise ValueError('expected call to be cancelled')
async def test_stream_stream(self):
    """FullDuplexCall works via the async API and via the sync API.

    The sync stub is driven from another thread. All checks use
    ``unittest`` assertions rather than bare ``assert``. Bare asserts are
    stripped under ``python -O``, which would let this test pass without
    checking anything, and they were inconsistent with the
    ``self.assertEqual`` already used in the sync half.
    """
    request = messages_pb2.StreamingOutputCallRequest()
    request.response_parameters.append(
        messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

    # Calling async API in this thread: one write followed by one read, per
    # round.
    call = self._async_stub.FullDuplexCall()
    for _ in range(_NUM_STREAM_RESPONSES):
        await call.write(request)
        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
    await call.done_writing()
    self.assertEqual(grpc.StatusCode.OK, await call.code())

    # Calling sync API in a different thread
    def sync_work() -> None:
        response_iterator = self._sync_stub.FullDuplexCall(iter([request]))
        for response in response_iterator:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                             len(response.payload.body))
        self.assertEqual(grpc.StatusCode.OK, response_iterator.code())

    await self._run_in_another_thread(sync_work)
async def test_max_message_length_applied(self):
    """The channel's max-receive-length option fails oversized responses.

    The server is stopped in a ``finally`` clause. Previously any failing
    assertion skipped ``server.stop`` and leaked the test server.
    """
    address, server = await start_test_server()
    try:
        async with aio.insecure_channel(
                address,
                options=((_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH,
                          _MAX_MESSAGE_LENGTH),)) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            request = messages_pb2.StreamingOutputCallRequest()
            # First response fits under the limit and should be received.
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_MAX_MESSAGE_LENGTH // 2))
            # Second response exceeds the limit and should fail the call.
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_MAX_MESSAGE_LENGTH * 2))

            call = stub.StreamingOutputCall(request)

            response = await call.read()
            self.assertEqual(_MAX_MESSAGE_LENGTH // 2,
                             len(response.payload.body))

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call.read()
            rpc_error = exception_context.exception
            self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                             rpc_error.code())
            self.assertIn(str(_MAX_MESSAGE_LENGTH), rpc_error.details())

            self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                             await call.code())
    finally:
        await server.stop(None)
async def test_cancel_by_the_interceptor(self):
    """A unary-stream call cancelled inside an interceptor reports CANCELLED.

    The channel targets UNREACHABLE_TARGET; the interceptor cancels the call
    before handing it back. The channel is managed with ``async with`` so
    it is closed even when an assertion fails.
    """

    class Interceptor(aio.UnaryStreamClientInterceptor):

        async def intercept_unary_stream(self, continuation,
                                         client_call_details, request):
            call = await continuation(client_call_details, request)
            call.cancel()
            return call

    async with aio.insecure_channel(UNREACHABLE_TARGET,
                                    interceptors=[Interceptor()]) as channel:
        request = messages_pb2.StreamingOutputCallRequest()
        stub = test_pb2_grpc.TestServiceStub(channel)
        call = stub.StreamingOutputCall(request)

        # Iterating an already-cancelled call must raise CancelledError.
        with self.assertRaises(asyncio.CancelledError):
            async for _ in call:
                pass

        self.assertTrue(call.cancelled())
        self.assertTrue(call.done())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
async def test_stream_stream_using_async_gen(self):
    """FullDuplexCall accepts an async generator as its request iterator.

    The test checks every response and also the number of responses. Each
    request carries a single response parameter, so one response is
    expected per request. The channel is managed with ``async with`` so it
    is closed even when an assertion fails.
    """
    async with aio.insecure_channel(self._server_target) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)

        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

        async def gen():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        # Invokes the actual RPC
        call = stub.FullDuplexCall(gen())

        response_cnt = 0
        async for response in call:
            response_cnt += 1
            self.assertIsInstance(response,
                                  messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                             len(response.payload.body))

        self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
        self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_timeout(self): call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S) # The error should be raised automatically without any traffic. with self.assertRaises(aio.AioRpcError) as exception_context: await call rpc_error = exception_context.exception self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code()) self.assertTrue(call.done()) self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await call.code()) # Prepares the request that stream in a ping-pong manner. _STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest() _STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append( messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase): async def test_cancel(self): # Invokes the actual RPC call = self._stub.FullDuplexCall() for _ in range(_NUM_STREAM_RESPONSES): await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) response = await call.read() self.assertIsInstance(response, messages_pb2.StreamingOutputCallResponse) self.assertEqual(_RESPONSE_PAYLOAD_SIZE,