async def test_intercepts_request_iterator_rpc_error_using_write(self):
    """Check that RPC errors surface through interceptors with the write API.

    For each stream-unary interceptor class, a call is opened against
    ``UNREACHABLE_TARGET`` and requests are pushed with ``call.write``. The
    test expects the writes to raise ``asyncio.InvalidStateError`` and
    awaiting the call to raise ``aio.AioRpcError`` with status
    ``UNAVAILABLE``.
    """
    for interceptor_class in (_StreamUnaryInterceptorEmpty,
                              _StreamUnaryInterceptorWithRequestIterator):

        with self.subTest(name=interceptor_class):
            # The context manager closes the channel even when an assertion
            # below fails, so a failing sub-test does not leak its channel
            # into the next one.
            async with aio.insecure_channel(
                    UNREACHABLE_TARGET,
                    interceptors=[interceptor_class()]) as channel:
                stub = test_pb2_grpc.TestServiceStub(channel)

                payload = messages_pb2.Payload(body=b'\0' *
                                               _REQUEST_PAYLOAD_SIZE)
                request = messages_pb2.StreamingInputCallRequest(
                    payload=payload)

                call = stub.StreamingInputCall()

                # When there is an error during the write, exception is
                # raised.
                with self.assertRaises(asyncio.InvalidStateError):
                    for _ in range(_NUM_STREAM_REQUESTS):
                        await call.write(request)

                with self.assertRaises(aio.AioRpcError) as exception_context:
                    await call

                self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                                 exception_context.exception.code())
                self.assertTrue(call.done())
                self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                                 await call.code())
async def test_interceptor_stream_stream(self):
    """Run StreamingInputCall through a logging server interceptor.

    Streams ``_NUM_STREAM_RESPONSES`` copies of one request, then checks the
    response type, the aggregated payload size, the OK status code, and that
    the record list holds exactly one ``intercept_service`` entry.
    """
    events = []
    # NOTE(review): `server` is not used below; presumably it is kept bound
    # so the server stays alive for the duration of the test - confirm
    # before removing.
    server, stub = await _create_server_stub_pair(
        _LoggingInterceptor('log_stream_stream', events))

    # One request message is built once and streamed repeatedly.
    request = messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))

    async def request_stream():
        for _ in range(_NUM_STREAM_RESPONSES):
            yield request

    # Start the RPC and wait for its single response.
    call = stub.StreamingInputCall(request_stream())
    response = await call

    expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE
    self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
    self.assertEqual(expected_size, response.aggregated_payload_size)
    self.assertEqual(await call.code(), grpc.StatusCode.OK)

    # The interceptor must have been recorded exactly once.
    self.assertSequenceEqual([
        'log_stream_stream:intercept_service',
    ], events)
async def test_stream_unary_using_write(self):
    """Drive a stream-unary RPC with explicit write()/done_writing() calls.

    Writes ``_NUM_STREAM_RESPONSES`` requests, half-closes the stream, and
    checks the response type, aggregated payload size and OK status code.
    """
    # The context manager closes the channel even if an assertion fails.
    async with aio.insecure_channel(self._server_target) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)

        # Invokes the actual RPC
        call = stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)
        await call.done_writing()

        # Validates the responses
        response = await call
        self.assertIsInstance(response,
                              messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_cancel_by_the_interceptor(self):
    """Check call state when the interceptor itself cancels the call.

    The interceptor obtains the call from the continuation, cancels it and
    returns it. The test expects writes to raise
    ``asyncio.InvalidStateError``, awaiting the call to raise
    ``asyncio.CancelledError``, and the call to report cancelled/done with
    status ``CANCELLED``.
    """

    class Interceptor(aio.StreamUnaryClientInterceptor):

        async def intercept_stream_unary(self, continuation,
                                         client_call_details,
                                         request_iterator):
            call = await continuation(client_call_details, request_iterator)
            call.cancel()
            return call

    # The context manager closes the channel even if an assertion fails.
    async with aio.insecure_channel(UNREACHABLE_TARGET,
                                    interceptors=[Interceptor()]) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)

        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        call = stub.StreamingInputCall()

        with self.assertRaises(asyncio.InvalidStateError):
            for _ in range(_NUM_STREAM_REQUESTS):
                await call.write(request)

        with self.assertRaises(asyncio.CancelledError):
            await call

        self.assertTrue(call.cancelled())
        self.assertTrue(call.done())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
async def test_exception_raised_by_interceptor(self):
    """Check that an exception raised inside the interceptor reaches the caller.

    The interceptor raises ``InterceptorException`` without invoking the
    continuation. The test expects that same exception both from
    ``call.write`` and from awaiting the call.
    """

    class InterceptorException(Exception):
        pass

    class Interceptor(aio.StreamUnaryClientInterceptor):

        async def intercept_stream_unary(self, continuation,
                                         client_call_details,
                                         request_iterator):
            raise InterceptorException

    # The context manager closes the channel even if an assertion fails.
    async with aio.insecure_channel(UNREACHABLE_TARGET,
                                    interceptors=[Interceptor()]) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)

        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        call = stub.StreamingInputCall()

        with self.assertRaises(InterceptorException):
            for _ in range(_NUM_STREAM_REQUESTS):
                await call.write(request)

        with self.assertRaises(InterceptorException):
            await call
async def test_add_done_callback_interceptor_task_finished(self):
    """Inject callbacks into an intercepted call that has already completed.

    For each stream-unary interceptor class, the RPC is awaited to
    completion first; only then is ``inject_callbacks`` applied to the call
    and its result awaited.
    """
    for interceptor_class in (_StreamUnaryInterceptorEmpty,
                              _StreamUnaryInterceptorWithRequestIterator):

        with self.subTest(name=interceptor_class):
            interceptor = interceptor_class()

            # The context manager closes the channel even when this sub-test
            # fails, so the channel is not leaked into the next sub-test.
            async with aio.insecure_channel(
                    self._server_target,
                    interceptors=[interceptor]) as channel:
                stub = test_pb2_grpc.TestServiceStub(channel)

                payload = messages_pb2.Payload(body=b'\0' *
                                               _REQUEST_PAYLOAD_SIZE)
                request = messages_pb2.StreamingInputCallRequest(
                    payload=payload)

                async def request_iterator():
                    for _ in range(_NUM_STREAM_REQUESTS):
                        yield request

                call = stub.StreamingInputCall(request_iterator())

                # Finish the RPC before injecting callbacks; the response
                # itself is not inspected by this test.
                await call

                # NOTE(review): inject_callbacks is defined elsewhere;
                # presumably it returns an awaitable that validates the
                # callbacks ran - confirm against its definition.
                validation = inject_callbacks(call)

                await validation
async def test_cancel_while_writing(self):
    """Cancel an intercepted call before any write and after one write.

    In both sub-tests the call is cancelled inside the write loop. The test
    expects the loop to raise ``asyncio.InvalidStateError``, awaiting the
    call to raise ``asyncio.CancelledError``, and the call to report
    cancelled/done with status ``CANCELLED``.
    """
    # Test cancelation before making any write or after doing at least 1
    for num_writes_before_cancel in (0, 1):
        with self.subTest(name="Num writes before cancel: {}".format(
                num_writes_before_cancel)):

            # The context manager closes the channel even when this sub-test
            # fails, so the channel is not leaked into the next sub-test.
            async with aio.insecure_channel(
                    UNREACHABLE_TARGET,
                    interceptors=[_StreamUnaryInterceptorWithRequestIterator()
                                 ]) as channel:
                stub = test_pb2_grpc.TestServiceStub(channel)

                payload = messages_pb2.Payload(body=b'\0' *
                                               _REQUEST_PAYLOAD_SIZE)
                request = messages_pb2.StreamingInputCallRequest(
                    payload=payload)

                call = stub.StreamingInputCall()

                with self.assertRaises(asyncio.InvalidStateError):
                    for i in range(_NUM_STREAM_REQUESTS):
                        # Cancel right before the write with this index.
                        if i == num_writes_before_cancel:
                            self.assertTrue(call.cancel())
                        await call.write(request)

                with self.assertRaises(asyncio.CancelledError):
                    await call

                self.assertTrue(call.cancelled())
                self.assertTrue(call.done())
                self.assertEqual(await call.code(),
                                 grpc.StatusCode.CANCELLED)
def _cancel_after_begin(stub):
    """Start a client-streaming call as a future and cancel it immediately.

    Args:
        stub: Object exposing ``StreamingInputCall.future``.

    Raises:
        ValueError: If the future does not report itself as cancelled after
            ``cancel()`` was called.
    """
    body_sizes = (27182, 8, 1828, 45904)
    # Requests are produced lazily, one per configured body size.
    request_stream = (
        messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b'\x00' * body_size))
        for body_size in body_sizes)

    future = stub.StreamingInputCall.future(request_stream)
    future.cancel()
    if future.cancelled():
        return
    raise ValueError('expected call to be cancelled')
def _client_streaming(stub):
    """Stream a fixed set of payloads and verify the aggregated size.

    Args:
        stub: Object whose ``StreamingInputCall`` accepts an iterator of
            requests and returns a response with an
            ``aggregated_payload_size`` attribute.

    Raises:
        ValueError: If the reported aggregated size differs from the total
            number of payload bytes sent.
    """
    payload_body_sizes = (27182, 8, 1828, 45904)
    payloads = (messages_pb2.Payload(body=b'\x00' * size)
                for size in payload_body_sizes)
    requests = (messages_pb2.StreamingInputCallRequest(payload=payload)
                for payload in payloads)
    response = stub.StreamingInputCall(requests)

    # Derive the expected total from the sizes actually sent (74922 for the
    # tuple above) so the check cannot drift if the sizes are edited.
    expected_size = sum(payload_body_sizes)
    if response.aggregated_payload_size != expected_size:
        raise ValueError('incorrect size %d!' %
                         response.aggregated_payload_size)
async def _perform_stream_unary(stub, wait_for_ready):
    """Issue one stream-unary RPC and wait for it to complete.

    Args:
        stub: Object exposing ``StreamingInputCall``.
        wait_for_ready: Forwarded unchanged as the ``wait_for_ready``
            keyword of the call.
    """
    request = messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))

    async def request_stream():
        for _ in range(_NUM_STREAM_RESPONSES):
            yield request

    call = stub.StreamingInputCall(request_stream(),
                                   timeout=test_constants.LONG_TIMEOUT,
                                   wait_for_ready=wait_for_ready)
    # The response is intentionally discarded; only completion matters.
    await call
async def test_write_after_done_writing(self):
    """Check that write() is rejected once done_writing() has been called.

    The RPC itself is still expected to complete with the aggregated size of
    the requests written before the half-close and an OK status code.
    """
    call = self._stub.StreamingInputCall()

    # Build the request that is written repeatedly.
    request = messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))

    # Push every request through the write API.
    for _ in range(_NUM_STREAM_RESPONSES):
        await call.write(request)

    # Half-close the request stream; a later write must be rejected.
    await call.done_writing()

    with self.assertRaises(asyncio.InvalidStateError):
        await call.write(messages_pb2.StreamingInputCallRequest())

    response = await call
    expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE
    self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
    self.assertEqual(expected_size, response.aggregated_payload_size)

    self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_normal_iterable_requests(self):
    """Check that a plain list of requests is accepted as the request stream.

    A regular list, not an async iterator, is handed to
    ``StreamingInputCall``; the RPC is expected to succeed with the full
    aggregated payload size and an OK status code.
    """
    # The same request object is repeated in an ordinary list.
    request = messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))
    request_list = [request for _ in range(_NUM_STREAM_RESPONSES)]

    # Hand the list straight to the stub.
    call = self._stub.StreamingInputCall(request_list)

    # The RPC is expected to complete normally.
    response = await call
    expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE
    self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
    self.assertEqual(expected_size, response.aggregated_payload_size)

    self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_early_cancel_stream_unary(self):
    """Cancel a stream-unary call before anything has been written.

    After cancellation, ``write`` is expected to raise
    ``asyncio.InvalidStateError``, ``done_writing`` is expected not to
    raise, and awaiting the call is expected to raise
    ``asyncio.CancelledError``.
    """
    call = self._stub.StreamingInputCall()

    # Nothing has happened yet, so the call is neither done nor cancelled.
    self.assertFalse(call.done())
    self.assertFalse(call.cancelled())

    # cancel() reports success and the state flips immediately.
    self.assertTrue(call.cancel())
    self.assertTrue(call.cancelled())

    empty_request = messages_pb2.StreamingInputCallRequest()
    with self.assertRaises(asyncio.InvalidStateError):
        await call.write(empty_request)

    # Closing the write side of a cancelled call must not raise.
    await call.done_writing()

    with self.assertRaises(asyncio.CancelledError):
        await call
async def test_stream_unary(self):
    """Run a stream-unary RPC fed by an async generator, with callbacks.

    ``inject_callbacks`` is applied to the call before it is awaited; its
    result is awaited last, after the response size and status code have
    been checked.
    """
    request = messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))

    async def request_stream():
        for _ in range(_NUM_STREAM_RESPONSES):
            yield request

    call = self._stub.StreamingInputCall(request_stream())

    # Callbacks are injected while the call is still in flight.
    callbacks_validation = inject_callbacks(call)

    response = await call
    expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE
    self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
    self.assertEqual(expected_size, response.aggregated_payload_size)

    self.assertEqual(grpc.StatusCode.OK, await call.code())

    await callbacks_validation
async def test_stream_unary_ok(self):
    """Check that wait_for_connection() does not disturb a stream-unary RPC.

    After waiting for the connection, all requests are written, the stream
    is half-closed, and the response must account for every written payload
    with an OK status code.
    """
    call = self._stub.StreamingInputCall()

    # Must neither raise nor swallow any message.
    await call.wait_for_connection()

    request = messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))

    for _ in range(_NUM_STREAM_RESPONSES):
        await call.write(request)
    await call.done_writing()

    response = await call
    expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE
    self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
    self.assertEqual(expected_size, response.aggregated_payload_size)

    self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_cancel_after_rpc(self):
    """Cancel a call while its interceptor is parked after the continuation.

    The interceptor invokes the continuation, signals an event and then
    waits on a future that is never resolved. Once the event is set, the
    test cancels the call and expects writes to raise
    ``asyncio.InvalidStateError``, awaiting to raise
    ``asyncio.CancelledError``, status ``CANCELLED``, and ``None`` for both
    initial and trailing metadata.
    """
    interceptor_reached = asyncio.Event()
    wait_for_ever = self.loop.create_future()

    class Interceptor(aio.StreamUnaryClientInterceptor):

        async def intercept_stream_unary(self, continuation,
                                         client_call_details,
                                         request_iterator):
            # NOTE(review): `call` looks unused, but the binding keeps the
            # underlying call referenced while this coroutine is parked -
            # presumably intentional; confirm before removing.
            call = await continuation(client_call_details, request_iterator)
            interceptor_reached.set()
            await wait_for_ever

    # The context manager closes the channel even if an assertion fails.
    async with aio.insecure_channel(self._server_target,
                                    interceptors=[Interceptor()]) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)

        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        call = stub.StreamingInputCall()

        self.assertFalse(call.cancelled())
        self.assertFalse(call.done())

        # Only cancel once the interceptor is past the continuation.
        await interceptor_reached.wait()
        self.assertTrue(call.cancel())

        # When there is an error during the write, exception is raised.
        with self.assertRaises(asyncio.InvalidStateError):
            for _ in range(_NUM_STREAM_REQUESTS):
                await call.write(request)

        with self.assertRaises(asyncio.CancelledError):
            await call

        self.assertTrue(call.cancelled())
        self.assertTrue(call.done())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
        self.assertEqual(await call.initial_metadata(), None)
        self.assertEqual(await call.trailing_metadata(), None)
async def test_cancel_stream_unary(self):
    """Cancel a stream-unary call after all requests have been written.

    ``done_writing`` is still called after the cancellation and is expected
    not to raise; awaiting the call is expected to raise
    ``asyncio.CancelledError``.
    """
    call = self._stub.StreamingInputCall()

    # Build the request that is written repeatedly.
    request = messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))

    # Push every request through the write API.
    for _ in range(_NUM_STREAM_RESPONSES):
        await call.write(request)

    # The call is still open, so it is neither done nor cancelled yet.
    self.assertFalse(call.done())
    self.assertFalse(call.cancelled())

    # cancel() reports success and the state flips immediately.
    self.assertTrue(call.cancel())
    self.assertTrue(call.cancelled())

    await call.done_writing()

    with self.assertRaises(asyncio.CancelledError):
        await call
async def test_stream_unary(self):
    """Issue the same stream-unary RPC through the async and the sync stub.

    The async stub is awaited directly; the sync stub is invoked from a
    helper function handed to ``self._run_in_another_thread``. Both
    responses must report the same aggregated payload size.
    """
    request = messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))
    expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE

    # Async API, awaited directly from this coroutine.
    async def request_stream():
        for _ in range(_NUM_STREAM_RESPONSES):
            yield request

    async_response = await self._async_stub.StreamingInputCall(
        request_stream())
    self.assertEqual(expected_size, async_response.aggregated_payload_size)

    # Sync API, executed through the thread helper.
    def sync_work() -> None:
        sync_requests = iter([request] * _NUM_STREAM_RESPONSES)
        sync_response = self._sync_stub.StreamingInputCall(sync_requests)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         sync_response.aggregated_payload_size)

    await self._run_in_another_thread(sync_work)
async def test_intercepts_prohibit_mixing_style(self):
    """Check that the write API is rejected when a request iterator was given.

    The call is created with an async request iterator; both ``write`` and
    ``done_writing`` are then expected to raise ``UsageError``.
    """
    # The context manager closes the channel even if an assertion fails.
    async with aio.insecure_channel(
            self._server_target,
            interceptors=[_StreamUnaryInterceptorEmpty()]) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)

        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        async def request_iterator():
            for _ in range(_NUM_STREAM_REQUESTS):
                yield request

        call = stub.StreamingInputCall(request_iterator())

        with self.assertRaises(grpc._cython.cygrpc.UsageError):
            await call.write(request)

        with self.assertRaises(grpc._cython.cygrpc.UsageError):
            await call.done_writing()
async def test_multiple_interceptors_request_iterator(self):
    """Run a stream-unary RPC through two chained interceptors.

    For each interceptor class, two instances are installed on the channel
    and the call is fed by an async request iterator. The test checks the
    aggregated payload size, status code, metadata, details, debug string
    and final call state, then asks every interceptor to assert its own
    final state.
    """
    for interceptor_class in (_StreamUnaryInterceptorEmpty,
                              _StreamUnaryInterceptorWithRequestIterator):

        with self.subTest(name=interceptor_class):

            interceptors = [interceptor_class(), interceptor_class()]

            # The context manager closes the channel even when this sub-test
            # fails, so the channel is not leaked into the next sub-test.
            async with aio.insecure_channel(
                    self._server_target,
                    interceptors=interceptors) as channel:
                stub = test_pb2_grpc.TestServiceStub(channel)

                payload = messages_pb2.Payload(body=b'\0' *
                                               _REQUEST_PAYLOAD_SIZE)
                request = messages_pb2.StreamingInputCallRequest(
                    payload=payload)

                async def request_iterator():
                    for _ in range(_NUM_STREAM_REQUESTS):
                        yield request

                call = stub.StreamingInputCall(request_iterator())

                response = await call

                self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
                                 response.aggregated_payload_size)
                self.assertEqual(await call.code(), grpc.StatusCode.OK)
                self.assertEqual(await call.initial_metadata(),
                                 aio.Metadata())
                self.assertEqual(await call.trailing_metadata(),
                                 aio.Metadata())
                self.assertEqual(await call.details(), '')
                self.assertEqual(await call.debug_error_string(), '')
                # The call already finished, so cancel() must report False.
                self.assertEqual(call.cancel(), False)
                self.assertEqual(call.cancelled(), False)
                self.assertEqual(call.done(), True)

                for interceptor in interceptors:
                    interceptor.assert_in_final_state(self)
async def request_gen():
    """Yield one StreamingInputCallRequest per entry of payload_body_sizes.

    ``payload_body_sizes`` is not defined in this function; it is read from
    the surrounding scope. Each request carries a payload of zero bytes of
    the corresponding size.
    """
    for body_size in payload_body_sizes:
        payload = messages_pb2.Payload(body=b'\x00' * body_size)
        yield messages_pb2.StreamingInputCallRequest(payload=payload)