async def test_intercepts_request_iterator_rpc_error_using_write(self):
        """Check error reporting when writing to an RPC on an unreachable target.

        For each interceptor class a ``StreamingInputCall`` is opened against
        ``UNREACHABLE_TARGET`` and requests are written one by one.  The write
        loop is expected to raise ``asyncio.InvalidStateError`` and awaiting
        the call is expected to raise ``aio.AioRpcError`` with status
        ``UNAVAILABLE``.
        """
        for interceptor_class in (_StreamUnaryInterceptorEmpty,
                                  _StreamUnaryInterceptorWithRequestIterator):

            with self.subTest(name=interceptor_class):
                channel = aio.insecure_channel(
                    UNREACHABLE_TARGET, interceptors=[interceptor_class()])
                # Close the channel even if an assertion below fails, so a
                # failing sub-test does not leave it open for the next one.
                try:
                    stub = test_pb2_grpc.TestServiceStub(channel)

                    payload = messages_pb2.Payload(body=b'\0' *
                                                   _REQUEST_PAYLOAD_SIZE)
                    request = messages_pb2.StreamingInputCallRequest(
                        payload=payload)

                    call = stub.StreamingInputCall()

                    # The write loop must surface the failure as an exception.
                    with self.assertRaises(asyncio.InvalidStateError):
                        for _ in range(_NUM_STREAM_REQUESTS):
                            await call.write(request)

                    with self.assertRaises(
                            aio.AioRpcError) as exception_context:
                        await call

                    self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                                     exception_context.exception.code())
                    self.assertTrue(call.done())
                    self.assertEqual(grpc.StatusCode.UNAVAILABLE, await
                                     call.code())
                finally:
                    await channel.close()
Esempio n. 2
0
    async def test_interceptor_stream_stream(self):
        """Run a client-streaming RPC through a logging interceptor.

        The interceptor is handed to ``_create_server_stub_pair`` and is
        expected to record exactly one ``intercept_service`` entry.

        NOTE(review): the test name and log prefix say "stream_stream", but
        the RPC invoked here is ``StreamingInputCall`` and a single response
        is awaited -- confirm the naming is intended.
        """
        intercepted = []
        logging_interceptor = _LoggingInterceptor('log_stream_stream',
                                                  intercepted)
        server, stub = await _create_server_stub_pair(logging_interceptor)

        # One request message, re-sent on every iteration of the stream.
        request = messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))
        expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE

        async def request_stream():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        # Start the RPC and wait for its single aggregated response.
        call = stub.StreamingInputCall(request_stream())
        response = await call

        self.assertIsInstance(response,
                              messages_pb2.StreamingInputCallResponse)
        self.assertEqual(expected_size, response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        # The interceptor must have been hit exactly once.
        self.assertSequenceEqual([
            'log_stream_stream:intercept_service',
        ], intercepted)
Esempio n. 3
0
    async def test_stream_unary_using_write(self):
        """Send requests with ``write()`` and validate the aggregated response.

        Opens a channel to ``self._server_target``, writes
        ``_NUM_STREAM_RESPONSES`` requests, half-closes with
        ``done_writing()`` and checks the response size and ``OK`` status.
        """
        channel = aio.insecure_channel(self._server_target)
        # Close the channel even if an assertion below fails.
        try:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Invokes the actual RPC
            call = stub.StreamingInputCall()

            # Prepares the request
            payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            # Sends out requests
            for _ in range(_NUM_STREAM_RESPONSES):
                await call.write(request)
            await call.done_writing()

            # Validates the responses
            response = await call
            self.assertIsInstance(response,
                                  messages_pb2.StreamingInputCallResponse)
            self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                             response.aggregated_payload_size)

            self.assertEqual(await call.code(), grpc.StatusCode.OK)
        finally:
            await channel.close()
    async def test_cancel_by_the_interceptor(self):
        """Check the call state when the interceptor itself cancels the call.

        The interceptor invokes the continuation, cancels the returned call
        and hands it back.  Writing is then expected to raise
        ``asyncio.InvalidStateError``, awaiting the call is expected to raise
        ``asyncio.CancelledError``, and the call must report ``CANCELLED``.
        """

        class Interceptor(aio.StreamUnaryClientInterceptor):

            async def intercept_stream_unary(self, continuation,
                                             client_call_details,
                                             request_iterator):
                call = await continuation(client_call_details, request_iterator)
                call.cancel()
                return call

        channel = aio.insecure_channel(UNREACHABLE_TARGET,
                                       interceptors=[Interceptor()])
        # Close the channel even if an assertion below fails.
        try:
            stub = test_pb2_grpc.TestServiceStub(channel)

            payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            call = stub.StreamingInputCall()

            with self.assertRaises(asyncio.InvalidStateError):
                for _ in range(_NUM_STREAM_REQUESTS):
                    await call.write(request)

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
        finally:
            await channel.close()
    async def test_exception_raised_by_interceptor(self):
        """Check that an exception raised inside the interceptor is propagated.

        The interceptor raises ``InterceptorException`` without invoking the
        continuation.  Both the write loop and awaiting the call are expected
        to raise that same exception type.
        """

        class InterceptorException(Exception):
            pass

        class Interceptor(aio.StreamUnaryClientInterceptor):

            async def intercept_stream_unary(self, continuation,
                                             client_call_details,
                                             request_iterator):
                raise InterceptorException

        channel = aio.insecure_channel(UNREACHABLE_TARGET,
                                       interceptors=[Interceptor()])
        # Close the channel even if an assertion below fails.
        try:
            stub = test_pb2_grpc.TestServiceStub(channel)

            payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            call = stub.StreamingInputCall()

            with self.assertRaises(InterceptorException):
                for _ in range(_NUM_STREAM_REQUESTS):
                    await call.write(request)

            with self.assertRaises(InterceptorException):
                await call
        finally:
            await channel.close()
    async def test_add_done_callback_interceptor_task_finished(self):
        """Register done-callbacks on a call that has already completed.

        For each interceptor class the RPC is awaited to completion first;
        only then is ``inject_callbacks`` applied to the call and its result
        awaited.
        """
        for interceptor_class in (_StreamUnaryInterceptorEmpty,
                                  _StreamUnaryInterceptorWithRequestIterator):

            with self.subTest(name=interceptor_class):
                interceptor = interceptor_class()

                channel = aio.insecure_channel(self._server_target,
                                               interceptors=[interceptor])
                # Close the channel even if the validation below fails, so a
                # failing sub-test does not leave it open for the next one.
                try:
                    stub = test_pb2_grpc.TestServiceStub(channel)

                    payload = messages_pb2.Payload(body=b'\0' *
                                                   _REQUEST_PAYLOAD_SIZE)
                    request = messages_pb2.StreamingInputCallRequest(
                        payload=payload)

                    async def request_iterator():
                        for _ in range(_NUM_STREAM_REQUESTS):
                            yield request

                    call = stub.StreamingInputCall(request_iterator())

                    # Only completion matters here; the response is not
                    # inspected.
                    await call

                    validation = inject_callbacks(call)

                    await validation
                finally:
                    await channel.close()
    async def test_cancel_while_writing(self):
        """Cancel the call in the middle of the write loop.

        Two sub-tests are run: the call is cancelled either before the first
        write or after one write.  In both cases the write loop is expected to
        raise ``asyncio.InvalidStateError``, awaiting the call is expected to
        raise ``asyncio.CancelledError``, and the call must report
        ``CANCELLED``.
        """
        # Test cancellation before making any write or after doing at least 1
        for num_writes_before_cancel in (0, 1):
            with self.subTest(name="Num writes before cancel: {}".format(
                    num_writes_before_cancel)):

                channel = aio.insecure_channel(
                    UNREACHABLE_TARGET,
                    interceptors=[_StreamUnaryInterceptorWithRequestIterator()])
                # Close the channel even if an assertion below fails, so a
                # failing sub-test does not leave it open for the next one.
                try:
                    stub = test_pb2_grpc.TestServiceStub(channel)

                    payload = messages_pb2.Payload(body=b'\0' *
                                                   _REQUEST_PAYLOAD_SIZE)
                    request = messages_pb2.StreamingInputCallRequest(
                        payload=payload)

                    call = stub.StreamingInputCall()

                    with self.assertRaises(asyncio.InvalidStateError):
                        for i in range(_NUM_STREAM_REQUESTS):
                            # Cancel right before the write with this index.
                            if i == num_writes_before_cancel:
                                self.assertTrue(call.cancel())
                            await call.write(request)

                    with self.assertRaises(asyncio.CancelledError):
                        await call

                    self.assertTrue(call.cancelled())
                    self.assertTrue(call.done())
                    self.assertEqual(await call.code(),
                                     grpc.StatusCode.CANCELLED)
                finally:
                    await channel.close()
Esempio n. 8
0
def _cancel_after_begin(stub):
    """Start a client-streaming call as a future and cancel it immediately.

    Args:
        stub: object exposing ``StreamingInputCall.future(requests)``.

    Raises:
        ValueError: if the future does not report itself as cancelled after
            ``cancel()`` has been called.
    """
    body_sizes = (27182, 8, 1828, 45904)
    # Lazily built: each request is only created when the stream is consumed.
    request_stream = (
        messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b'\x00' * body_size))
        for body_size in body_sizes)
    pending = stub.StreamingInputCall.future(request_stream)
    pending.cancel()
    if pending.cancelled():
        return
    raise ValueError('expected call to be cancelled')
Esempio n. 9
0
def _client_streaming(stub):
    """Send four zero-filled payloads and verify the aggregated size.

    Args:
        stub: object exposing a blocking ``StreamingInputCall(requests)``.

    Raises:
        ValueError: if ``aggregated_payload_size`` of the response differs
            from the total number of bytes sent.
    """
    payload_body_sizes = (27182, 8, 1828, 45904)
    # 27182 + 8 + 1828 + 45904 == 74922
    expected_total = sum(payload_body_sizes)
    # Lazily built: each request is only created when the stream is consumed.
    request_stream = (
        messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b'\x00' * body_size))
        for body_size in payload_body_sizes)
    response = stub.StreamingInputCall(request_stream)
    received_total = response.aggregated_payload_size
    if received_total != expected_total:
        raise ValueError('incorrect size %d!' %
                         response.aggregated_payload_size)
Esempio n. 10
0
async def _perform_stream_unary(stub, wait_for_ready):
    """Run one ``StreamingInputCall`` to completion.

    Args:
        stub: object exposing an async ``StreamingInputCall``.
        wait_for_ready: forwarded unchanged as the ``wait_for_ready`` keyword
            of the RPC invocation.
    """
    request = messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))

    async def request_stream():
        # The same request message is yielded on every iteration.
        for _ in range(_NUM_STREAM_RESPONSES):
            yield request

    call = stub.StreamingInputCall(request_stream(),
                                   timeout=test_constants.LONG_TIMEOUT,
                                   wait_for_ready=wait_for_ready)
    await call
Esempio n. 11
0
    async def test_write_after_done_writing(self):
        """Check that ``write()`` is rejected once ``done_writing()`` was called.

        The RPC itself must still complete normally with the payload that was
        written before ``done_writing()``.
        """
        call = self._stub.StreamingInputCall()

        # One request message, written repeatedly.
        request = messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))
        expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Signal end of the request stream; writes after this must fail.
        await call.done_writing()

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        # The rejected write must not affect the outcome of the RPC.
        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(expected_size, response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)
Esempio n. 12
0
    async def test_normal_iterable_requests(self):
        """Pass a plain list of requests instead of an async iterator."""
        request = messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))
        # A regular (synchronous) list holding the same message N times.
        request_list = [request] * _NUM_STREAM_RESPONSES
        expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE

        call = self._stub.StreamingInputCall(request_list)

        # The RPC is expected to complete successfully.
        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(expected_size, response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)
Esempio n. 13
0
    async def test_early_cancel_stream_unary(self):
        """Cancel a stream-unary call before anything has been written.

        After cancellation ``write()`` is expected to raise
        ``asyncio.InvalidStateError``, ``done_writing()`` is expected not to
        raise, and awaiting the call is expected to raise
        ``asyncio.CancelledError``.
        """
        call = self._stub.StreamingInputCall()

        # State before cancellation.
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())

        # Cancellation succeeds and is visible immediately.
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        with self.assertRaises(asyncio.InvalidStateError):
            empty_request = messages_pb2.StreamingInputCallRequest()
            await call.write(empty_request)

        # Expected to complete without raising on a cancelled call.
        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call
Esempio n. 14
0
    async def test_stream_unary(self):
        """Run a stream-unary RPC with done-callbacks injected before awaiting."""
        request = messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))
        expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE

        async def request_stream():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        call = self._stub.StreamingInputCall(request_stream())
        # Callbacks are attached while the call is still in flight.
        callbacks_done = inject_callbacks(call)

        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(expected_size, response.aggregated_payload_size)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        await callbacks_done
Esempio n. 15
0
    async def test_stream_unary_ok(self):
        """Await ``wait_for_connection()`` first, then run the RPC normally."""
        call = self._stub.StreamingInputCall()

        # Must not raise, and must not consume any message of the stream.
        await call.wait_for_connection()

        request = messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))
        expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)
        await call.done_writing()

        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(expected_size, response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)
    async def test_cancel_after_rpc(self):
        """Cancel the call while the interceptor is parked after the continuation.

        The interceptor invokes the continuation, signals
        ``interceptor_reached`` and then awaits a future that this test never
        resolves.  The test waits for the signal, cancels the call and checks
        that writes raise ``asyncio.InvalidStateError``, awaiting raises
        ``asyncio.CancelledError``, the status is ``CANCELLED`` and both
        metadata accessors return ``None``.
        """

        interceptor_reached = asyncio.Event()
        wait_for_ever = self.loop.create_future()

        class Interceptor(aio.StreamUnaryClientInterceptor):

            async def intercept_stream_unary(self, continuation,
                                             client_call_details,
                                             request_iterator):
                # NOTE(review): the binding is kept exactly as in the original;
                # presumably it keeps the underlying call object referenced
                # while the interceptor is parked -- confirm before removing.
                call = await continuation(client_call_details, request_iterator)
                interceptor_reached.set()
                await wait_for_ever

        channel = aio.insecure_channel(self._server_target,
                                       interceptors=[Interceptor()])
        # Close the channel even if an assertion below fails.
        try:
            stub = test_pb2_grpc.TestServiceStub(channel)

            payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            call = stub.StreamingInputCall()

            self.assertFalse(call.cancelled())
            self.assertFalse(call.done())

            await interceptor_reached.wait()
            self.assertTrue(call.cancel())

            # Writing to the cancelled call must raise.
            with self.assertRaises(asyncio.InvalidStateError):
                for _ in range(_NUM_STREAM_REQUESTS):
                    await call.write(request)

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
            self.assertEqual(await call.initial_metadata(), None)
            self.assertEqual(await call.trailing_metadata(), None)
        finally:
            await channel.close()
Esempio n. 17
0
    async def test_cancel_stream_unary(self):
        """Cancel a stream-unary call after all requests have been written."""
        call = self._stub.StreamingInputCall()

        # One request message, written repeatedly.
        request = messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # State before cancellation.
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())

        # Cancellation succeeds and is visible immediately.
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call
Esempio n. 18
0
    async def test_stream_unary(self):
        """Exercise the same RPC through the async stub and the sync stub.

        The async stub is awaited directly; the sync stub is invoked inside
        ``sync_work``, which is handed to ``self._run_in_another_thread``.
        """
        request = messages_pb2.StreamingInputCallRequest(
            payload=messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE))
        expected_size = _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE

        # Async API, awaited from this coroutine.
        async def request_stream():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        async_response = await self._async_stub.StreamingInputCall(
            request_stream())
        self.assertEqual(expected_size, async_response.aggregated_payload_size)

        # Sync API, executed via the helper that runs it elsewhere.
        def sync_work() -> None:
            sync_requests = iter([request] * _NUM_STREAM_RESPONSES)
            sync_response = self._sync_stub.StreamingInputCall(sync_requests)
            self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                             sync_response.aggregated_payload_size)

        await self._run_in_another_thread(sync_work)
    async def test_intercepts_prohibit_mixing_style(self):
        """Reject ``write()``/``done_writing()`` on an iterator-driven call.

        The call is created with a request iterator; calling ``write()`` or
        ``done_writing()`` on it afterwards is expected to raise
        ``UsageError``.
        """
        channel = aio.insecure_channel(
            self._server_target, interceptors=[_StreamUnaryInterceptorEmpty()])
        # Close the channel even if an assertion below fails.
        try:
            stub = test_pb2_grpc.TestServiceStub(channel)

            payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            async def request_iterator():
                for _ in range(_NUM_STREAM_REQUESTS):
                    yield request

            call = stub.StreamingInputCall(request_iterator())

            with self.assertRaises(grpc._cython.cygrpc.UsageError):
                await call.write(request)

            with self.assertRaises(grpc._cython.cygrpc.UsageError):
                await call.done_writing()
        finally:
            await channel.close()
Esempio n. 20
0
    async def test_multiple_interceptors_request_iterator(self):
        """Run a successful RPC through two stacked interceptors.

        For each interceptor class two instances are installed on the
        channel.  After the RPC completes, the response size, status,
        metadata, details, debug string and call-state accessors are checked,
        and every interceptor is asked to verify its own final state.
        """
        for interceptor_class in (_StreamUnaryInterceptorEmpty,
                                  _StreamUnaryInterceptorWithRequestIterator):

            with self.subTest(name=interceptor_class):

                interceptors = [interceptor_class(), interceptor_class()]
                channel = aio.insecure_channel(self._server_target,
                                               interceptors=interceptors)
                # Close the channel even if an assertion below fails, so a
                # failing sub-test does not leave it open for the next one.
                try:
                    stub = test_pb2_grpc.TestServiceStub(channel)

                    payload = messages_pb2.Payload(body=b'\0' *
                                                   _REQUEST_PAYLOAD_SIZE)
                    request = messages_pb2.StreamingInputCallRequest(
                        payload=payload)

                    async def request_iterator():
                        for _ in range(_NUM_STREAM_REQUESTS):
                            yield request

                    call = stub.StreamingInputCall(request_iterator())

                    response = await call

                    self.assertEqual(
                        _NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
                        response.aggregated_payload_size)
                    self.assertEqual(await call.code(), grpc.StatusCode.OK)
                    self.assertEqual(await call.initial_metadata(),
                                     aio.Metadata())
                    self.assertEqual(await call.trailing_metadata(),
                                     aio.Metadata())
                    self.assertEqual(await call.details(), '')
                    self.assertEqual(await call.debug_error_string(), '')
                    # cancel() on the finished call is expected to return
                    # False and leave it non-cancelled.
                    self.assertEqual(call.cancel(), False)
                    self.assertEqual(call.cancelled(), False)
                    self.assertEqual(call.done(), True)

                    for interceptor in interceptors:
                        interceptor.assert_in_final_state(self)
                finally:
                    await channel.close()
Esempio n. 21
0
 async def request_gen():
     """Yield one zero-filled request per entry of ``payload_body_sizes``."""
     for body_size in payload_body_sizes:
         zero_payload = messages_pb2.Payload(body=b'\x00' * body_size)
         yield messages_pb2.StreamingInputCallRequest(payload=zero_payload)