Beispiel #1
0
    async def test_cancel_before_rpc(self):
        """Cancelling while the interceptor is still blocked aborts the call.

        The interceptor parks on a future that is never resolved, so the
        cancellation lands before ``continuation`` is ever invoked.
        """
        reached = asyncio.Event()
        never_done = self.loop.create_future()

        class Interceptor(aio.UnaryUnaryClientInterceptor):

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                # Tell the test we got here, then block indefinitely.
                reached.set()
                await never_done

        async with aio.insecure_channel(
                self._server_target,
                interceptors=[Interceptor()]) as channel:
            unary_call = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = unary_call(messages_pb2.SimpleRequest())

            # A freshly created call is neither cancelled nor finished.
            for state in (call.cancelled, call.done):
                self.assertFalse(state())

            await reached.wait()
            self.assertTrue(call.cancel())

            with self.assertRaises(asyncio.CancelledError):
                await call

            for state in (call.cancelled, call.done):
                self.assertTrue(state())

            expectations = (
                (call.code, grpc.StatusCode.CANCELLED),
                (call.details, _LOCAL_CANCEL_DETAILS_EXPECTATION),
                (call.initial_metadata, None),
                (call.trailing_metadata, None),
            )
            for getter, expected in expectations:
                self.assertEqual(await getter(), expected)
    async def test_executed_right_order(self):
        """Server interceptors each run once, in registration order."""
        record = []
        interceptors = tuple(
            _LoggingInterceptor(tag, record) for tag in ('log1', 'log2'))
        server_target, _ = await start_test_server(interceptors=interceptors)

        async with aio.insecure_channel(server_target) as channel:
            unary_call = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            response = await unary_call(messages_pb2.SimpleRequest())

            # Both interceptors must have fired, first-registered first.
            expected_log = ['log1:intercept_service', 'log2:intercept_service']
            self.assertSequenceEqual(expected_log, record)
            self.assertIsInstance(response, messages_pb2.SimpleResponse)
Beispiel #3
0
 async def test_maximum_concurrent_rpcs(self):
     """RPCs beyond ``maximum_concurrent_rpcs`` are queued, not run at once.

     Issuing three times the limit of briefly-blocking RPCs must take longer
     than three batches' worth of ``test_constants.SHORT_TIMEOUT / 2``.
     """
     # Build the server with concurrent rpc argument
     server = aio.server(maximum_concurrent_rpcs=_MAXIMUM_CONCURRENT_RPCS)
     port = server.add_insecure_port('localhost:0')
     bind_address = "localhost:%d" % port
     server.add_generic_rpc_handlers((_GenericHandler(), ))
     await server.start()
     # Build the channel
     channel = aio.insecure_channel(bind_address)
     try:
         # Deplete the concurrent quota with 3 times of max RPCs
         rpcs = [
             channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
             for _ in range(3 * _MAXIMUM_CONCURRENT_RPCS)
         ]
         # Each batch takes test_constants.SHORT_TIMEOUT / 2.
         # gather() wraps every call object in a task. Since Python 3.11,
         # asyncio.wait() no longer does that for arbitrary awaitables.
         # gather() also re-raises the first RPC error instead of leaving
         # it unretrieved.
         # The monotonic clock is immune to wall-clock adjustments.
         start_time = time.monotonic()
         await asyncio.gather(*rpcs)
         elapsed_time = time.monotonic() - start_time
         self.assertGreater(elapsed_time,
                            test_constants.SHORT_TIMEOUT * 3 / 2)
     finally:
         # Clean-up runs even when an RPC or the assertion fails.
         await channel.close()
         await server.stop(0)
Beispiel #4
0
    async def test_unary_stream_evilly_mixed(self):
        """``read()`` and ``async for`` can be mixed on one unary-stream call."""
        async with aio.insecure_channel(self._server_target) as channel:
            call = channel.unary_stream(_UNARY_STREAM_EVILLY_MIXED)(_REQUEST)

            # The request must reach the server before any response is
            # consumed.
            await asyncio.wait_for(self._generic_handler.wait_for_call(),
                                   test_constants.SHORT_TIMEOUT)

            # The first message goes through the reader API...
            first = await call.read()
            self.assertEqual(_RESPONSE, first)

            # ...and the remaining ones through the async-iterator API.
            remaining = 0
            async for message in call:
                self.assertEqual(_RESPONSE, message)
                remaining += 1

            self.assertEqual(_NUM_STREAM_RESPONSES - 1, remaining)
            self.assertEqual(await call.code(), grpc.StatusCode.OK)
Beispiel #5
0
    async def test_peer(self):
        """``ServicerContext.peer()`` reports an ip-prefixed peer address."""

        @grpc.unary_unary_rpc_method_handler
        async def check_peer_unary_unary(request: bytes,
                                         context: aio.ServicerContext):
            self.assertEqual(_REQUEST, request)
            # The peer address could be ipv4 or ipv6
            self.assertIn('ip', context.peer())
            return request

        # Creates a server
        server = aio.server()
        handlers = grpc.method_handlers_generic_handler(
            'test', {'UnaryUnary': check_peer_unary_unary})
        server.add_generic_rpc_handlers((handlers, ))
        port = server.add_insecure_port('[::]:0')
        await server.start()
        try:
            # Creates a channel
            async with aio.insecure_channel('localhost:%d' % port) as channel:
                response = await channel.unary_unary(_TEST_METHOD)(_REQUEST)
                self.assertEqual(_REQUEST, response)
        finally:
            # Stop the server even if the RPC or an assertion fails, so it
            # does not outlive this test.
            await server.stop(None)
Beispiel #6
0
    async def setUp(self):
        """Start an async test server and build async and sync stubs to it."""
        server = aio.server(options=(('grpc.so_reuseport', 0), ),
                            migration_thread_pool=ThreadPoolExecutor())
        self._async_server = server

        test_pb2_grpc.add_TestServiceServicer_to_server(TestServiceServicer(),
                                                        server)
        self._adhoc_handlers = _AdhocGenericHandler()
        server.add_generic_rpc_handlers((self._adhoc_handlers, ))

        address = 'localhost:%d' % server.add_insecure_port('[::]:0')
        await server.start()

        # Async channel/stub pair.
        self._async_channel = aio.insecure_channel(address,
                                                   options=_unique_options())
        self._async_stub = test_pb2_grpc.TestServiceStub(self._async_channel)

        # Sync channel/stub pair against the same server.
        self._sync_channel = grpc.insecure_channel(address,
                                                   options=_unique_options())
        self._sync_stub = test_pb2_grpc.TestServiceStub(self._sync_channel)
Beispiel #7
0
    async def test_stream_stream_using_async_gen(self):
        """A full-duplex call can be fed requests from an async generator."""
        # The context manager closes the channel even if an assertion fails.
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

            async def gen():
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request

            # Invokes the actual RPC
            call = stub.FullDuplexCall(gen())

            async for response in call:
                self.assertIsInstance(response,
                                      messages_pb2.StreamingOutputCallResponse)
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))

            self.assertEqual(grpc.StatusCode.OK, await call.code())
Beispiel #8
0
    async def test_stream_unary_using_async_gen(self):
        """A stream-unary call can be fed requests from an async generator."""
        # The context manager closes the channel even if an assertion fails.
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
            request = messages_pb2.StreamingInputCallRequest(payload=payload)

            async def gen():
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request

            # Invokes the actual RPC
            call = stub.StreamingInputCall(gen())

            # Validates the responses
            response = await call
            self.assertIsInstance(response,
                                  messages_pb2.StreamingInputCallResponse)
            self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                             response.aggregated_payload_size)

            self.assertEqual(await call.code(), grpc.StatusCode.OK)
Beispiel #9
0
    async def test_intercepts_response_iterator_rpc_error(self):
        """RPC errors surface through unary-stream interceptors.

        For each interceptor variant, iterating a call to
        ``UNREACHABLE_TARGET`` raises ``AioRpcError`` with ``UNAVAILABLE``.
        It also leaves the call done, reporting that same status code.
        """
        for interceptor_class in (_UnaryStreamInterceptorEmpty,
                                  _UnaryStreamInterceptorWithResponseIterator):

            with self.subTest(name=interceptor_class):
                # The context manager closes the channel even when an
                # assertion fails, so a failing sub-test does not leak it.
                async with aio.insecure_channel(
                        UNREACHABLE_TARGET,
                        interceptors=[interceptor_class()]) as channel:
                    request = messages_pb2.StreamingOutputCallRequest()
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    call = stub.StreamingOutputCall(request)

                    with self.assertRaises(
                            aio.AioRpcError) as exception_context:
                        async for _ in call:
                            pass

                    self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                                     exception_context.exception.code())

                    self.assertTrue(call.done())
                    self.assertEqual(grpc.StatusCode.UNAVAILABLE, await
                                     call.code())
    async def kill_actor(self, req) -> aiohttp.web.Response:
        """Ask the core worker at ``ipAddress:port`` to kill an actor.

        Reads the ``actorId``, ``ipAddress`` and ``port`` query parameters,
        then sends a ``KillActor`` RPC to that worker.

        Returns:
            A failed "Bad Request" response if a query parameter is missing;
            otherwise a success response naming the actor id.
        """
        try:
            actor_id = req.query["actorId"]
            ip_address = req.query["ipAddress"]
            port = req.query["port"]
        except KeyError:
            return rest_response(success=False, message="Bad Request")
        channel = aiogrpc.insecure_channel(f"{ip_address}:{port}")
        try:
            stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel)

            await stub.KillActor(
                core_worker_pb2.KillActorRequest(
                    intended_actor_id=ray.utils.hex_to_binary(actor_id)))

        except aiogrpc.AioRpcError:
            # This always throws an exception because the worker
            # is killed and the channel is closed on the worker side
            # before this handler, however it deletes the actor correctly.
            pass
        finally:
            # A new channel is created per request. Close it so repeated
            # kill requests do not leak channels.
            await channel.close()

        return rest_response(
            success=True, message=f"Killed actor with id {actor_id}")
Beispiel #11
0
    async def test_call_rpc_error(self):
        """A timed-out call raises DEADLINE_EXCEEDED and caches that error."""
        async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel:
            call = test_pb2_grpc.TestServiceStub(channel).UnaryCall(
                messages_pb2.SimpleRequest(), timeout=0.1)

            with self.assertRaises(grpc.RpcError) as first_attempt:
                await call
            self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
                             first_attempt.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
                             await call.code())

            # Awaiting the finished call again re-raises the very same
            # exception object cached on the call.
            with self.assertRaises(grpc.RpcError) as second_attempt:
                await call
            self.assertIs(first_attempt.exception, second_attempt.exception)
    async def test_apply_different_interceptors_by_metadata(self):
        """A filtered server interceptor runs only when its predicate holds.

        ``log3`` is wrapped in a filter that fires only for calls whose
        invocation metadata contains the ``('secret', '42')`` pair.
        """
        record = []
        conditional_interceptor = _filter_server_interceptor(
            lambda handler_call_details:
            ('secret', '42') in handler_call_details.invocation_metadata,
            _LoggingInterceptor('log3', record))
        server_target, _ = await start_test_server(interceptors=(
            _LoggingInterceptor('log1', record),
            conditional_interceptor,
            _LoggingInterceptor('log2', record),
        ))

        async with aio.insecure_channel(server_target) as channel:
            unary_call = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            scenarios = (
                # Without the secret the conditional interceptor is skipped.
                ((('key', 'value'), ), [
                    'log1:intercept_service',
                    'log2:intercept_service',
                ]),
                # With the secret it runs between its neighbours.
                ((('key', 'value'), ('secret', '42')), [
                    'log1:intercept_service',
                    'log3:intercept_service',
                    'log2:intercept_service',
                ]),
            )
            for metadata, expected_log in scenarios:
                await unary_call(messages_pb2.SimpleRequest(),
                                 metadata=metadata)
                self.assertSequenceEqual(expected_log, record)
                record.clear()
    async def test_intercepts(self):
        """Stream-unary interceptors pass a successful call through intact.

        The response and every call accessor must look exactly like an
        un-intercepted, successfully completed call.
        """
        for interceptor_class in (_StreamUnaryInterceptorEmpty,
                                  _StreamUnaryInterceptorWithRequestIterator):

            with self.subTest(name=interceptor_class):
                interceptor = interceptor_class()
                # The context manager closes the channel even when an
                # assertion fails, so a failing sub-test does not leak it.
                async with aio.insecure_channel(
                        self._server_target,
                        interceptors=[interceptor]) as channel:
                    stub = test_pb2_grpc.TestServiceStub(channel)

                    payload = messages_pb2.Payload(body=b'\0' *
                                                   _REQUEST_PAYLOAD_SIZE)
                    request = messages_pb2.StreamingInputCallRequest(
                        payload=payload)

                    async def request_iterator():
                        for _ in range(_NUM_STREAM_REQUESTS):
                            yield request

                    call = stub.StreamingInputCall(request_iterator())

                    response = await call

                    self.assertEqual(
                        _NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
                        response.aggregated_payload_size)
                    self.assertEqual(await call.code(), grpc.StatusCode.OK)
                    self.assertEqual(await call.initial_metadata(),
                                     aio.Metadata())
                    self.assertEqual(await call.trailing_metadata(),
                                     aio.Metadata())
                    self.assertEqual(await call.details(), '')
                    self.assertEqual(await call.debug_error_string(), '')
                    self.assertEqual(call.cancel(), False)
                    self.assertEqual(call.cancelled(), False)
                    self.assertEqual(call.done(), True)

                    interceptor.assert_in_final_state(self)
    async def test_cancel_by_the_interceptor(self):
        """A call cancelled inside an interceptor is seen as cancelled."""

        class Interceptor(aio.UnaryStreamClientInterceptor):

            async def intercept_unary_stream(self, continuation,
                                             client_call_details, request):
                call = await continuation(client_call_details, request)
                call.cancel()
                return call

        # The context manager closes the channel even if an assertion fails.
        async with aio.insecure_channel(UNREACHABLE_TARGET,
                                        interceptors=[Interceptor()
                                                     ]) as channel:
            request = messages_pb2.StreamingOutputCallRequest()
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.StreamingOutputCall(request)

            with self.assertRaises(asyncio.CancelledError):
                async for _ in call:
                    pass

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
Beispiel #15
0
    async def test_add_timeout(self):
        """A client interceptor can impose a deadline the RPC cannot meet."""

        class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor):
            """Interceptor used for adding a timeout to the RPC"""

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):
                # Half of UNARY_CALL_WITH_SLEEP_VALUE (presumably the
                # server-side sleep) so the deadline expires first.
                shortened = aio.ClientCallDetails(
                    method=client_call_details.method,
                    timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
                    metadata=client_call_details.metadata,
                    credentials=client_call_details.credentials,
                    wait_for_ready=client_call_details.wait_for_ready)
                return await continuation(shortened, request)

        async with aio.insecure_channel(
                self._server_target,
                interceptors=[TimeoutInterceptor()]) as channel:
            sleepy_call = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCallWithSleep',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = sleepy_call(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as raised:
                await call
            self.assertEqual(raised.exception.code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
                             await call.code())
Beispiel #16
0
    async def test_unary_stream(self):
        """A unary-stream call yields the requested number of responses."""
        # The context manager closes the channel even if an assertion fails.
        async with aio.insecure_channel(self._server_target) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # Prepares the request
            request = messages_pb2.StreamingOutputCallRequest()
            for _ in range(_NUM_STREAM_RESPONSES):
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_RESPONSE_PAYLOAD_SIZE))

            # Invokes the actual RPC
            call = stub.StreamingOutputCall(request)

            # Validates the responses
            response_cnt = 0
            async for response in call:
                response_cnt += 1
                self.assertIs(type(response),
                              messages_pb2.StreamingOutputCallResponse)
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))

            self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
            self.assertEqual(await call.code(), grpc.StatusCode.OK)
Beispiel #17
0
 def __init__(self,
              node_ip_address,
              redis_address,
              dashboard_agent_port,
              redis_password=None,
              temp_dir=None,
              log_dir=None,
              metrics_export_port=None,
              node_manager_port=None,
              object_store_name=None,
              raylet_name=None):
     """Initialize the DashboardAgent object.

     Stores the arguments as public attributes shared with all agent
     modules. It then does three more things:

     - reads ``RAY_NODE_ID`` and ``RAY_RAYLET_PID`` from the environment;
     - creates (but does not start) the agent gRPC server on
       ``dashboard_agent_port``;
     - creates a gRPC channel to ``{node_ip_address}:{node_manager_port}``.

     Raises:
         KeyError: if ``RAY_NODE_ID`` or ``RAY_RAYLET_PID`` is unset.
         ValueError: if ``RAY_RAYLET_PID`` is not a positive integer.
     """
     # Public attributes are accessible for all agent modules.
     self.ip = node_ip_address
     self.redis_address = dashboard_utils.address_tuple(redis_address)
     self.redis_password = redis_password
     self.temp_dir = temp_dir
     self.log_dir = log_dir
     self.dashboard_agent_port = dashboard_agent_port
     self.metrics_export_port = metrics_export_port
     self.node_manager_port = node_manager_port
     self.object_store_name = object_store_name
     self.raylet_name = raylet_name
     self.node_id = os.environ["RAY_NODE_ID"]
     self.ppid = int(os.environ["RAY_RAYLET_PID"])
     # Use an explicit check rather than ``assert``. Asserts are stripped
     # under ``python -O``, which would silently accept an invalid pid.
     if self.ppid <= 0:
         raise ValueError(
             f"RAY_RAYLET_PID must be a positive pid, got {self.ppid}")
     logger.info("Parent pid is %s", self.ppid)
     self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
     # add_insecure_port returns the port the server actually bound.
     self.grpc_port = self.server.add_insecure_port(
         f"[::]:{self.dashboard_agent_port}")
     logger.info("Dashboard agent grpc address: %s:%s", self.ip,
                 self.grpc_port)
     # Presumably populated later (e.g. in run()); None until then.
     self.aioredis_client = None
     self.aiogrpc_raylet_channel = aiogrpc.insecure_channel(
         f"{self.ip}:{self.node_manager_port}")
     self.http_session = None
Beispiel #18
0
        async def coro():
            """Make one unary call and check its response is cached."""
            server_target, _ = await start_test_server()  # pylint: disable=unused-variable

            async with aio.insecure_channel(server_target) as channel:
                unary_call = channel.unary_unary(
                    '/grpc.testing.TestService/UnaryCall',
                    request_serializer=messages_pb2.SimpleRequest.
                    SerializeToString,
                    response_deserializer=messages_pb2.SimpleResponse.FromString
                )
                call = unary_call(messages_pb2.SimpleRequest())
                self.assertFalse(call.done())

                first = await call
                self.assertTrue(call.done())
                self.assertEqual(type(first), messages_pb2.SimpleResponse)
                self.assertEqual(await call.code(), grpc.StatusCode.OK)

                # Awaiting the finished call again returns the very same
                # response object cached on the call.
                self.assertIs(first, await call)
Beispiel #19
0
    async def test_max_message_length_applied(self):
        """The channel's max receive length rejects oversized responses.

        The first streamed response (half the limit) is received. The second
        (twice the limit) fails with ``RESOURCE_EXHAUSTED``.
        """
        address, server = await start_test_server()
        try:
            async with aio.insecure_channel(
                    address,
                    options=((_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH,
                              _MAX_MESSAGE_LENGTH), )) as channel:
                stub = test_pb2_grpc.TestServiceStub(channel)

                request = messages_pb2.StreamingOutputCallRequest()
                # First request will pass
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_MAX_MESSAGE_LENGTH // 2, ))
                # Second request should fail
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_MAX_MESSAGE_LENGTH * 2, ))

                call = stub.StreamingOutputCall(request)

                response = await call.read()
                self.assertEqual(_MAX_MESSAGE_LENGTH // 2,
                                 len(response.payload.body))

                with self.assertRaises(aio.AioRpcError) as exception_context:
                    await call.read()
                rpc_error = exception_context.exception
                self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                                 rpc_error.code())
                self.assertIn(str(_MAX_MESSAGE_LENGTH), rpc_error.details())

                self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, await
                                 call.code())
        finally:
            # Stop the server even if an assertion above fails.
            await server.stop(None)
Beispiel #20
0
 async def test_client(self):
     """Creating and closing a channel to ``[::]:0`` must not crash."""
     # Do not segfault, or raise exception!
     await aio.insecure_channel('[::]:0', options=_TEST_CHANNEL_ARGS).close()
Beispiel #21
0
    async def test_retry(self):
        """An interceptor can retry: the first attempt times out, the second succeeds."""

        class RetryInterceptor(aio.UnaryUnaryClientInterceptor):
            """Simulates a Retry Interceptor which ends up by making
            two RPC calls."""

            def __init__(self):
                # Every call object produced by ``continuation``, in order.
                self.calls = []

            async def intercept_unary_unary(self, continuation,
                                            client_call_details, request):

                new_client_call_details = aio.ClientCallDetails(
                    method=client_call_details.method,
                    timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
                    metadata=client_call_details.metadata,
                    credentials=client_call_details.credentials,
                    wait_for_ready=client_call_details.wait_for_ready)

                # Only awaiting the call is expected to fail (deadline).
                # Keep ``continuation`` outside the ``try``, so ``call`` is
                # always bound before it is recorded. Previously an error
                # from ``continuation`` itself became UnboundLocalError.
                call = await continuation(new_client_call_details, request)
                try:
                    await call
                except grpc.RpcError:
                    pass

                self.calls.append(call)

                # Second attempt with no deadline, returned to the caller.
                new_client_call_details = aio.ClientCallDetails(
                    method=client_call_details.method,
                    timeout=None,
                    metadata=client_call_details.metadata,
                    credentials=client_call_details.credentials,
                    wait_for_ready=client_call_details.wait_for_ready)

                call = await continuation(new_client_call_details, request)
                self.calls.append(call)
                return call

        interceptor = RetryInterceptor()

        async with aio.insecure_channel(self._server_target,
                                        interceptors=[interceptor]) as channel:

            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCallWithSleep',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            call = multicallable(messages_pb2.SimpleRequest())

            await call

            self.assertEqual(grpc.StatusCode.OK, await call.code())

            # Check that two calls were made, first one finishing with
            # a deadline and second one finishing ok.
            self.assertEqual(len(interceptor.calls), 2)
            self.assertEqual(await interceptor.calls[0].code(),
                             grpc.StatusCode.DEADLINE_EXCEEDED)
            self.assertEqual(await interceptor.calls[1].code(),
                             grpc.StatusCode.OK)
Beispiel #22
0
 async def setUp(self):
     """Start the test server and open an async channel to it."""
     server_info = await _start_test_server()
     address, self._server = server_info
     self._client = aio.insecure_channel(address)
Beispiel #23
0
    async def run(self):
        """Start the dashboard head: connect backends, start servers, run modules.

        Steps, in order:

        1. Connect to redis; log and ``sys.exit(-1)`` on failure.
        2. Poll redis until the GCS address is available, then create a gRPC
           channel to it.
        3. Start the gRPC server and publish the HTTP and gRPC addresses to
           redis.
        4. Load modules and build the HTTP app.
        5. Run these together with ``asyncio.gather``: the background
           coroutines, the HTTP server coroutine and each module's
           ``run(self.server)``.
        6. Await ``self.server.wait_for_termination()``.
        """
        # Create an aioredis client for all modules.
        try:
            self.aioredis_client = await dashboard_utils.get_aioredis_client(
                self.redis_address, self.redis_password,
                dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
                dashboard_consts.RETRY_REDIS_CONNECTION_TIMES)
        except (socket.gaierror, ConnectionError):
            logger.error(
                "Dashboard head exiting: "
                "Failed to connect to redis at %s", self.redis_address)
            sys.exit(-1)

        # Create a http session for all modules.
        # NOTE(review): newer aiohttp releases deprecate passing `loop=` to
        # ClientSession. Confirm against the pinned aiohttp version.
        self.http_session = aiohttp.ClientSession(
            loop=asyncio.get_event_loop())

        # Wait until the GCS address is published in redis, retrying every
        # CONNECT_GCS_INTERVAL_SECONDS on any exception.
        # NOTE(review): creating the channel presumably does not verify
        # connectivity. In practice the loop may only wait for the redis key
        # to appear. Confirm.
        while True:
            try:
                gcs_address = await self.aioredis_client.get(
                    dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)
                if not gcs_address:
                    raise Exception("GCS address not found.")
                logger.info("Connect to GCS at %s", gcs_address)
                options = (("grpc.enable_http_proxy", 0), )
                channel = aiogrpc.insecure_channel(gcs_address,
                                                   options=options)
            except Exception as ex:
                logger.error("Connect to GCS failed: %s, retry...", ex)
                await asyncio.sleep(
                    dashboard_consts.CONNECT_GCS_INTERVAL_SECONDS)
            else:
                self.aiogrpc_gcs_channel = channel
                break

        # Create a NodeInfoGcsServiceStub.
        self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
            self.aiogrpc_gcs_channel)

        # Start a grpc asyncio server.
        await self.server.start()

        # Write the dashboard head HTTP and gRPC addresses to redis.
        await self.aioredis_client.set(dashboard_consts.REDIS_KEY_DASHBOARD,
                                       self.ip + ":" + str(self.http_port))
        await self.aioredis_client.set(
            dashboard_consts.REDIS_KEY_DASHBOARD_RPC,
            self.ip + ":" + str(self.grpc_port))

        async def _async_notify():
            """Await coroutines from NotifyQueue forever, logging failures."""
            while True:
                co = await dashboard_utils.NotifyQueue.get()
                try:
                    await co
                except Exception:
                    logger.exception(f"Error notifying coroutine {co}")

        modules = self._load_modules()

        # Http server should be initialized after all modules loaded.
        app = aiohttp.web.Application()
        app.add_routes(routes=routes.bound_routes())
        # NOTE(review): `_run_app` is a private aiohttp helper and may change
        # between aiohttp versions. It is awaited below via gather.
        web_server = aiohttp.web._run_app(app,
                                          host=self.http_host,
                                          port=self.http_port)

        # Dump registered http routes (HEAD routes are skipped).
        dump_routes = [
            r for r in app.router.routes() if r.method != hdrs.METH_HEAD
        ]
        for r in dump_routes:
            logger.info(r)
        logger.info("Registered %s routes.", len(dump_routes))

        # Freeze signal after all modules loaded.
        dashboard_utils.SignalManager.freeze()
        concurrent_tasks = [
            self._update_nodes(),
            _async_notify(),
            DataOrganizer.purge(),
            DataOrganizer.organize(),
            web_server,
        ]
        # NOTE(review): _async_notify loops forever, so gather presumably
        # only returns if a coroutine raises (and then it re-raises). The
        # final wait_for_termination() may be unreachable in normal
        # operation. Confirm.
        await asyncio.gather(*concurrent_tasks,
                             *(m.run(self.server) for m in modules))
        await self.server.wait_for_termination()
Beispiel #24
0
 async def coro():
     """Check that ``aio.insecure_channel`` returns an ``aio.Channel``."""
     channel = aio.insecure_channel(self.server_target)
     try:
         self.assertIsInstance(channel, aio.Channel)
     finally:
         # Close the channel; previously it was left open after the check.
         await channel.close()
Beispiel #25
0
    def test_invalid_interceptor(self):
        """A channel rejects interceptors of unsupported types with ValueError."""

        class InvalidInterceptor:
            """Just an invalid Interceptor"""

        bogus_interceptors = [InvalidInterceptor()]
        with self.assertRaises(ValueError):
            aio.insecure_channel("", interceptors=bogus_interceptors)
Beispiel #26
0
 async def setUp(self):
     """Start the test server and create a channel plus stub for it."""
     address, self._server = await start_test_server()
     channel = aio.insecure_channel(address)
     self._channel = channel
     self._stub = test_pb2_grpc.TestServiceStub(channel)
Beispiel #27
0
 async def setUp(self):
     """Start the generic-handler test server and open a channel to it."""
     server_parts = await _start_test_server()
     addr, self._server, self._generic_handler = server_parts
     self._channel = aio.insecure_channel(addr)
Beispiel #28
0
 async def setUp(self):
     """Aim a channel at a port held by a non-listening socket.

     The socket is closed right after the channel is created (presumably
     so the target is unreachable — confirm against the tests using it).
     """
     address, self._port, self._socket = get_socket(
         listen=False, sock_options=(socket.SO_REUSEADDR, ))
     target = f"{address}:{self._port}"
     self._channel = aio.insecure_channel(target)
     self._socket.close()
Beispiel #29
0
    async def run(self):
        """Start the dashboard head: connect backends, start servers, run modules.

        Steps, in order:

        1. Connect to redis; log and ``sys.exit(-1)`` on failure.
        2. Poll redis until the GCS address is available, then create a gRPC
           channel to it.
        3. Start the gRPC server and load modules.
        4. Bind the HTTP server, trying successive ports on ``OSError``.
        5. Publish the HTTP and gRPC addresses to redis.
        6. Run the background coroutines and each module's
           ``run(self.server)`` together with ``asyncio.gather``.
        7. Await ``self.server.wait_for_termination()``.

        Raises:
            Exception: if no HTTP port could be bound within
                ``1 + self.http_port_retries`` attempts.
        """
        # Create an aioredis client for all modules.
        try:
            self.aioredis_client = await dashboard_utils.get_aioredis_client(
                self.redis_address, self.redis_password,
                dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
                dashboard_consts.RETRY_REDIS_CONNECTION_TIMES)
        except (socket.gaierror, ConnectionError):
            logger.error(
                "Dashboard head exiting: "
                "Failed to connect to redis at %s", self.redis_address)
            sys.exit(-1)

        # Create a http session for all modules.
        # NOTE(review): newer aiohttp releases deprecate passing `loop=` to
        # ClientSession. Confirm against the pinned aiohttp version.
        self.http_session = aiohttp.ClientSession(
            loop=asyncio.get_event_loop())

        # Wait until the GCS address is published in redis, retrying every
        # GCS_RETRY_CONNECT_INTERVAL_SECONDS on any exception.
        # NOTE(review): creating the channel presumably does not verify
        # connectivity. In practice the loop may only wait for the redis key
        # to appear. Confirm.
        while True:
            try:
                gcs_address = await self.aioredis_client.get(
                    dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)
                if not gcs_address:
                    raise Exception("GCS address not found.")
                logger.info("Connect to GCS at %s", gcs_address)
                options = (("grpc.enable_http_proxy", 0), )
                channel = aiogrpc.insecure_channel(
                    gcs_address, options=options)
            except Exception as ex:
                logger.error("Connect to GCS failed: %s, retry...", ex)
                await asyncio.sleep(
                    dashboard_consts.GCS_RETRY_CONNECT_INTERVAL_SECONDS)
            else:
                self.aiogrpc_gcs_channel = channel
                break

        # Create a HeartbeatInfoGcsServiceStub.
        self._gcs_heartbeat_info_stub = \
            gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
                self.aiogrpc_gcs_channel)

        # Start a grpc asyncio server.
        await self.server.start()

        async def _async_notify():
            """Await coroutines from NotifyQueue forever, logging failures."""
            while True:
                co = await dashboard_utils.NotifyQueue.get()
                try:
                    await co
                except Exception:
                    logger.exception(f"Error notifying coroutine {co}")

        modules = self._load_modules()

        # Http server should be initialized after all modules loaded.
        app = aiohttp.web.Application()
        app.add_routes(routes=routes.bound_routes())

        runner = aiohttp.web.AppRunner(app)
        await runner.setup()
        # Try self.http_port, then successive ports, for up to
        # 1 + http_port_retries attempts. `for ... else` raises when no
        # attempt hit `break`.
        # NOTE(review): after the final failed attempt the port is still
        # incremented and a "Try to use port" warning is logged, although
        # no further attempt follows.
        last_ex = None
        for i in range(1 + self.http_port_retries):
            try:
                site = aiohttp.web.TCPSite(runner, self.http_host,
                                           self.http_port)
                await site.start()
                break
            except OSError as e:
                last_ex = e
                self.http_port += 1
                logger.warning("Try to use port %s: %s", self.http_port, e)
        else:
            raise Exception(f"Failed to find a valid port for dashboard after "
                            f"{self.http_port_retries} retries: {last_ex}")
        # Read back the address actually bound.
        # NOTE(review): `site._server` is a private aiohttp attribute.
        http_host, http_port, *_ = site._server.sockets[0].getsockname()
        # A wildcard bind address (e.g. 0.0.0.0 or ::) is not reachable as
        # such, so advertise the node IP instead.
        http_host = self.ip if ipaddress.ip_address(
            http_host).is_unspecified else http_host
        logger.info("Dashboard head http address: %s:%s", http_host, http_port)

        # Write the dashboard head HTTP and gRPC addresses to redis.
        await self.aioredis_client.set(ray_constants.REDIS_KEY_DASHBOARD,
                                       f"{http_host}:{http_port}")
        await self.aioredis_client.set(
            dashboard_consts.REDIS_KEY_DASHBOARD_RPC,
            f"{self.ip}:{self.grpc_port}")

        # Dump registered http routes (HEAD routes are skipped).
        dump_routes = [
            r for r in app.router.routes() if r.method != hdrs.METH_HEAD
        ]
        for r in dump_routes:
            logger.info(r)
        logger.info("Registered %s routes.", len(dump_routes))

        # Freeze signal after all modules loaded.
        dashboard_utils.SignalManager.freeze()
        concurrent_tasks = [
            self._gcs_check_alive(),
            _async_notify(),
            DataOrganizer.purge(),
            DataOrganizer.organize(),
        ]
        # NOTE(review): _async_notify loops forever, so gather presumably
        # only returns if a coroutine raises (and then it re-raises). The
        # final wait_for_termination() may be unreachable in normal
        # operation. Confirm.
        await asyncio.gather(*concurrent_tasks,
                             *(m.run(self.server) for m in modules))
        await self.server.wait_for_termination()
Beispiel #30
0
    async def test_insecure_channel(self):
        """``aio.insecure_channel`` returns an ``aio.Channel`` instance."""
        server_target, _ = await start_test_server()  # pylint: disable=unused-variable

        # The context manager closes the channel, which was previously
        # left open after the test.
        async with aio.insecure_channel(server_target) as channel:
            self.assertIsInstance(channel, aio.Channel)