Ejemplo n.º 1
0
    async def test_late_cancel_unary_stream(self):
        """Test cancellation after received all messages.

        Every expected response is read before ``cancel()`` is invoked. At
        that point the final status may already have arrived or may still be
        in flight, so both OK and CANCELLED are acceptable outcomes; the test
        only requires that the late cancel does not crash.
        """
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.extend(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
            for _ in range(_NUM_STREAM_RESPONSES))

        call = self._stub.StreamingOutputCall(request)

        # Drain exactly the number of responses that were requested.
        remaining = _NUM_STREAM_RESPONSES
        while remaining > 0:
            message = await call.read()
            self.assertIs(type(message),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(message.payload.body))
            remaining -= 1

        # Racing with the server's final status is intentional here.
        call.cancel()
        final_code = await call.code()
        self.assertIn(final_code,
                      [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])
Ejemplo n.º 2
0
    async def test_error_in_async_generator(self):
        """Cancelling a stream-input call must reach its async-generator iterator.

        ``cancel_later`` cancels the call while the request iterator is still
        yielding. Both ``await call`` and the request iterator itself are
        expected to observe ``asyncio.CancelledError``.
        """
        # The response parameters ask for a pause (interval_us) between
        # responses. NOTE(review): this StreamingOutputCallRequest is sent to
        # StreamingInputCall, which may ignore these fields -- confirm intent.
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        # Set by the request iterator once it has observed the cancellation.
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            # Yields one request per SHORT_TIMEOUT, so the cancel issued after
            # 2 * SHORT_TIMEOUT presumably lands while the loop is still
            # running (assumes _NUM_STREAM_RESPONSES > 2).
            # NOTE(review): if the loop completes without being cancelled,
            # assertRaises fails inside this generator, the event below is
            # never set, and the final wait() blocks until an outer test
            # timeout (if any) fires.
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(test_constants.SHORT_TIMEOUT)
            request_iterator_received_the_exception.set()

        call = self._stub.StreamingInputCall(request_iterator())

        # Cancel the RPC after the iterator has (presumably) yielded at least
        # one request.
        async def cancel_later():
            await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        # Awaiting the cancelled call must raise CancelledError instead of
        # returning a response.
        with self.assertRaises(asyncio.CancelledError):
            await call

        # Confirms the cancellation was also delivered to the request iterator.
        await request_iterator_received_the_exception.wait()

        # No failures in the cancel later task!
        await cancel_later_task
Ejemplo n.º 3
0
def _custom_metadata(stub):
    """Verify that custom initial and trailing metadata are echoed back.

    The same metadata pair is attached to a UnaryCall and to a
    FullDuplexCall. For each RPC, the initial and trailing metadata of the
    response must contain the values that were sent.

    Args:
      stub: A TestService stub used to issue both RPCs.

    Raises:
      ValueError: If an expected metadata entry is missing or differs from
        the value that was sent.
    """
    initial_metadata_value = "test_initial_metadata_value"
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
                (_TRAILING_METADATA_KEY, trailing_metadata_value))

    def _check_value(kind, received, key, expected):
        # ``received or ()`` tolerates metadata that was never delivered.
        # ``.get`` turns a missing key into the descriptive ValueError below
        # instead of a bare KeyError.
        actual = dict(received or ()).get(key)
        if actual != expected:
            raise ValueError('expected %s metadata %s, got %s' %
                             (kind, expected, actual))

    def _validate_metadata(response):
        # Initial metadata is checked before trailing metadata is requested,
        # preserving the original call order.
        _check_value('initial', response.initial_metadata(),
                     _INITIAL_METADATA_KEY, initial_metadata_value)
        _check_value('trailing', response.trailing_metadata(),
                     _TRAILING_METADATA_KEY, trailing_metadata_value)

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'))
    response_future = stub.UnaryCall.future(request, metadata=metadata)
    _validate_metadata(response_future)

    # Testing with FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1), ))
        pipe.add(request)  # Sends the request
        next(response_iterator)  # Reads the single requested response
    # Leaving the with block closes the pipe, ending the request stream so
    # that the RPC can complete and deliver its trailing metadata.
    _validate_metadata(response_iterator)
Ejemplo n.º 4
0
    async def test_unary_stream(self):
        """A unary-stream RPC delivers every requested response, then OK.

        The channel is managed with ``async with`` so it is closed even when
        an assertion fails part-way through the stream.
        """
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            for _ in range(_NUM_STREAM_RESPONSES):
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_RESPONSE_PAYLOAD_SIZE))

            # Invokes the actual RPC
            call = stub.StreamingOutputCall(request)

            # Validates the responses
            response_cnt = 0
            async for response in call:
                response_cnt += 1
                self.assertIs(type(response),
                              messages_pb2.StreamingOutputCallResponse)
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))

            self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
            self.assertEqual(await call.code(), grpc.StatusCode.OK)
Ejemplo n.º 5
0
def _cancel_after_first_response(stub):
    """Cancel a FullDuplexCall right after its first response arrives.

    Raises:
      grpc.RpcError: If the read after cancellation fails with any status
        other than CANCELLED.
      ValueError: If the read after cancellation succeeds.
    """
    # Only the first entry of each size table is used by this test.
    response_sizes = (31415, 9, 2653, 58979)
    payload_sizes = (27182, 8, 1828, 45904)

    first_request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(
            size=response_sizes[0]), ),
        payload=messages_pb2.Payload(body=b'\x00' * payload_sizes[0]))

    with _Pipe() as pipe:
        responses = stub.FullDuplexCall(pipe)
        pipe.add(first_request)
        # The response contents are checked by the Ping Pong test; here the
        # first response is only consumed.
        next(responses)
        responses.cancel()

        try:
            next(responses)
        except grpc.RpcError as rpc_error:
            # Anything other than CANCELLED is a genuine failure.
            if rpc_error.code() is not grpc.StatusCode.CANCELLED:
                raise
        else:
            raise ValueError('expected call to be cancelled')
Ejemplo n.º 6
0
    async def test_stream_stream(self):
        """Run full-duplex RPCs via the async stub and then the sync stub.

        The async half runs on this thread. The sync half runs on another
        thread through ``self._run_in_another_thread``. unittest assertion
        methods are used instead of bare ``assert`` so that the checks are not
        stripped under ``python -O``.
        """
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

        # Calling async API in this thread
        call = self._async_stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)
            response = await call.read()
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                             len(response.payload.body))

        await call.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Calling sync API in a different thread
        def sync_work() -> None:
            response_iterator = self._sync_stub.FullDuplexCall(iter([request]))
            response_count = 0
            for sync_response in response_iterator:
                response_count += 1
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(sync_response.payload.body))
            # The async half shows that one such request yields one response.
            # Counting ensures an empty stream cannot pass silently.
            self.assertEqual(1, response_count)
            self.assertEqual(grpc.StatusCode.OK, response_iterator.code())

        await self._run_in_another_thread(sync_work)
Ejemplo n.º 7
0
    async def test_max_message_length_applied(self):
        """A response above the channel's receive limit fails the stream.

        The channel is created with a maximum receive message length. The
        first response (half the limit) is delivered. The second (twice the
        limit) must fail with RESOURCE_EXHAUSTED, and the error details must
        mention the limit.
        """
        address, server = await start_test_server()
        try:
            async with aio.insecure_channel(
                    address,
                    options=((_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH,
                              _MAX_MESSAGE_LENGTH), )) as channel:
                stub = test_pb2_grpc.TestServiceStub(channel)

                request = messages_pb2.StreamingOutputCallRequest()
                # First request will pass
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_MAX_MESSAGE_LENGTH // 2, ))
                # Second request should fail
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_MAX_MESSAGE_LENGTH * 2, ))

                call = stub.StreamingOutputCall(request)

                response = await call.read()
                self.assertEqual(_MAX_MESSAGE_LENGTH // 2,
                                 len(response.payload.body))

                with self.assertRaises(aio.AioRpcError) as exception_context:
                    await call.read()
                rpc_error = exception_context.exception
                self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                                 rpc_error.code())
                self.assertIn(str(_MAX_MESSAGE_LENGTH), rpc_error.details())

                self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, await
                                 call.code())
        finally:
            # Stop the server even when an assertion above fails; otherwise
            # it would outlive this test.
            await server.stop(None)
    async def test_cancel_by_the_interceptor(self):
        """A unary-stream call cancelled inside an interceptor reports CANCELLED.

        The interceptor cancels the call before returning it, so iterating
        the call raises ``asyncio.CancelledError``. The channel targets
        UNREACHABLE_TARGET, so no server is involved. The channel is managed
        with ``async with`` so it is closed even when an assertion fails.
        """

        class Interceptor(aio.UnaryStreamClientInterceptor):

            async def intercept_unary_stream(self, continuation,
                                             client_call_details, request):
                call = await continuation(client_call_details, request)
                call.cancel()
                return call

        async with aio.insecure_channel(
                UNREACHABLE_TARGET, interceptors=[Interceptor()]) as channel:
            request = messages_pb2.StreamingOutputCallRequest()
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.StreamingOutputCall(request)

            with self.assertRaises(asyncio.CancelledError):
                async for _ in call:
                    pass

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
Ejemplo n.º 9
0
    async def test_stream_stream_using_async_gen(self):
        """A full-duplex RPC fed by an async generator echoes each request.

        Each of the _NUM_STREAM_RESPONSES requests asks for one response, so
        exactly that many responses are expected. The channel is managed with
        ``async with`` so it is closed even when an assertion fails.
        """
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

            async def gen():
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request

            # Invokes the actual RPC
            call = stub.FullDuplexCall(gen())

            response_cnt = 0
            async for response in call:
                response_cnt += 1
                self.assertIsInstance(response,
                                      messages_pb2.StreamingOutputCallResponse)
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))

            # Without a count, an empty response stream would pass silently.
            self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
            self.assertEqual(grpc.StatusCode.OK, await call.code())
Ejemplo n.º 10
0
    async def test_timeout(self):
        """A StreamingInputCall that sends nothing fails once its deadline passes."""
        call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S)

        # No request is ever sent; the deadline alone must end the RPC.
        with self.assertRaises(aio.AioRpcError) as caught:
            await call

        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
                         caught.exception.code())
        self.assertTrue(call.done())
        final_code = await call.code()
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, final_code)


# Request used to stream in a ping-pong manner: each send asks the server for
# exactly one response whose payload is _RESPONSE_PAYLOAD_SIZE bytes.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest(
    response_parameters=(
        messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE), ))


class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
    async def test_cancel(self):
        # Invokes the actual RPC
        call = self._stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
            response = await call.read()
            self.assertIsInstance(response,
                                  messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE,