Example #1
class GatewayRuntime(AsyncNewLoopRuntime, ABC):
    """
    The base Runtime that all GatewayRuntimes need to inherit from
    """

    def __init__(
        self,
        args: argparse.Namespace,
        cancel_event: Optional[
            Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event']
        ] = None,
        **kwargs,
    ):
        # this order is intentional: the timeout is needed in _set_topology_graph(), which is called by super().__init__()
        self.timeout_send = args.timeout_send
        if self.timeout_send:
            self.timeout_send /= 1e3  # convert ms to seconds
        super().__init__(args, cancel_event, **kwargs)

    def _set_topology_graph(self):
        # check whether this needs to be K8s-aware; a ConnectionPoolFactory may need to be created here
        import json

        graph_description = json.loads(self.args.graph_description)
        graph_conditions = json.loads(self.args.graph_conditions)
        deployments_disable_reduce = json.loads(self.args.deployments_disable_reduce)
        self._topology_graph = TopologyGraph(
            graph_description,
            graph_conditions,
            deployments_disable_reduce,
            timeout_send=self.timeout_send,
            retries=self.args.retries,
        )

    def _set_connection_pool(self):
        import json

        deployments_addresses = json.loads(self.args.deployments_addresses)
        # add the connections needed
        self._connection_pool = GrpcConnectionPool(
            logger=self.logger,
            compression=self.args.compression,
            metrics_registry=self.metrics_registry,
        )
        for deployment_name, addresses in deployments_addresses.items():
            for address in addresses:
                self._connection_pool.add_connection(
                    deployment=deployment_name, address=address, head=True
                )
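
The graph and address arguments parsed above arrive as JSON strings on the CLI namespace. Here is a minimal sketch of the payload shapes the two setters expect, inferred purely from how the values are consumed; the concrete keys below are illustrative assumptions, not Jina's documented schema:

import json

# hypothetical payloads, shaped only by how the runtime parses them
graph_description = json.dumps({'start-gateway': ['encoder'], 'encoder': ['end-gateway']})
deployments_addresses = json.dumps({'encoder': ['1.1.1.1:8081', '1.1.1.2:8081']})

# the same nested loop _set_connection_pool() runs, printing instead of pooling
for deployment_name, addresses in json.loads(deployments_addresses).items():
    for address in addresses:
        print(f'add head connection for {deployment_name} at {address}')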
Example #2
async def test_connection_pool(mocker, monkeypatch):
    close_mock_object, create_mock = await _mock_grpc(mocker, monkeypatch)

    pool = GrpcConnectionPool()
    send_mock = mocker.Mock()
    # patch out the network layer; mock_send is a small helper defined
    # elsewhere in the test module that records the call on send_mock
    pool._send_requests = lambda messages, connection, endpoint: mock_send(send_mock)

    pool.add_connection(deployment='encoder', head=False, address='1.1.1.1:53')
    pool.add_connection(deployment='encoder', head=False, address='1.1.1.2:53')
    results = pool.send_request(request=ControlRequest(command='STATUS'),
                                deployment='encoder',
                                head=False)
    assert len(results) == 1
    assert send_mock.call_count == 1
    assert create_mock.call_count == 2

    results = pool.send_request(request=ControlRequest(command='STATUS'),
                                deployment='encoder',
                                head=False)
    assert len(results) == 1
    assert send_mock.call_count == 2
    assert create_mock.call_count == 2

    # indexer has not been added yet, so nothing is sent
    results = pool.send_request(request=ControlRequest(command='STATUS'),
                                deployment='indexer',
                                head=False)
    assert len(results) == 0
    assert send_mock.call_count == 2
    assert create_mock.call_count == 2

    # add indexer now so requests can be sent to it
    pool.add_connection(deployment='indexer', head=False, address='2.1.1.1:53')
    results = pool.send_request(request=ControlRequest(command='STATUS'),
                                deployment='indexer',
                                head=False)
    assert len(results) == 1
    assert send_mock.call_count == 3
    assert create_mock.call_count == 3

    # polling only applies to shards, there are no shards here, so it only sends one message
    pool.add_connection(deployment='encoder', head=False, address='1.1.1.3:53')
    results = pool.send_request(
        request=ControlRequest(command='STATUS'),
        deployment='encoder',
        head=False,
        polling_type=PollingType.ALL,
    )
    assert len(results) == 1
    assert send_mock.call_count == 4
    assert create_mock.call_count == 4

    # polling only applies to shards, so we add a shard now and expect 2 messages to be sent
    pool.add_connection(deployment='encoder',
                        head=False,
                        address='1.1.1.3:53',
                        shard_id=1)
    # adding the same connection again is a noop
    pool.add_connection(deployment='encoder',
                        head=False,
                        address='1.1.1.3:53',
                        shard_id=1)
    results = pool.send_request(
        request=ControlRequest(command='STATUS'),
        deployment='encoder',
        head=False,
        polling_type=PollingType.ALL,
    )
    assert len(results) == 2
    assert send_mock.call_count == 6
    assert create_mock.call_count == 5

    # sending to one specific shard should only send one message
    results = pool.send_request(
        request=ControlRequest(command='STATUS'),
        deployment='encoder',
        head=False,
        polling_type=PollingType.ANY,
        shard_id=1,
    )
    assert len(results) == 1
    assert send_mock.call_count == 7

    # doing the same with polling ALL ignores the shard id
    results = pool.send_request(
        request=ControlRequest(command='STATUS'),
        deployment='encoder',
        head=False,
        polling_type=PollingType.ALL,
        shard_id=1,
    )
    assert len(results) == 2
    assert send_mock.call_count == 9

    # removing a replica for shard 0 works and does not prevent messages from being sent to the shard
    assert await pool.remove_connection(deployment='encoder',
                                        head=False,
                                        address='1.1.1.2:53',
                                        shard_id=0)
    assert close_mock_object.call_count == 1
    results = pool.send_request(
        request=ControlRequest(command='STATUS'),
        deployment='encoder',
        head=False,
        polling_type=PollingType.ANY,
        shard_id=0,
    )
    assert len(results) == 1
    assert send_mock.call_count == 10

    # encoder pod has no head registered yet so sending to the head will not work
    results = pool.send_request(request=ControlRequest(command='STATUS'),
                                deployment='encoder',
                                head=True)
    assert len(results) == 0
    assert send_mock.call_count == 10

    # after registering a head for encoder, sending to head should work
    pool.add_connection(deployment='encoder', head=True, address='1.1.1.10:53')
    results = pool.send_request(request=ControlRequest(command='STATUS'),
                                deployment='encoder',
                                head=True)
    assert len(results) == 1
    assert send_mock.call_count == 11

    # after removing the head again, sending will not work
    assert await pool.remove_connection(deployment='encoder',
                                        head=True,
                                        address='1.1.1.10:53')
    assert close_mock_object.call_count == 2
    results = pool.send_request(request=ControlRequest(command='STATUS'),
                                deployment='encoder',
                                head=True)
    assert len(results) == 0
    assert send_mock.call_count == 11

    # check that remove/add order is handled well
    pool.add_connection(deployment='encoder', head=False, address='1.1.1.4:53')
    assert await pool.remove_connection(deployment='encoder',
                                        head=False,
                                        address='1.1.1.1:53')
    assert await pool.remove_connection(deployment='encoder',
                                        head=False,
                                        address='1.1.1.4:53')
    assert close_mock_object.call_count == 4
    assert not (await pool.remove_connection(
        deployment='encoder', head=False, address='1.1.1.2:53'))

    await pool.close()
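
Note that this is an `async def` test, so it needs an async-capable runner. A minimal sketch of the declaration, assuming the suite uses pytest-asyncio (an assumption; the marker may instead be applied suite-wide via `asyncio_mode`), with `mocker` coming from pytest-mock:

import pytest

@pytest.mark.asyncio  # assumes pytest-asyncio is installed
async def test_connection_pool(mocker, monkeypatch):
    ...  # body as in the example above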
Example #3
async def test_grpc_connection_pool_real_sending_timeout():
    server1_ready_event = multiprocessing.Event()

    def listen(port, event: multiprocessing.Event):
        class DummyServer:
            async def process_control(self, request, *args):
                returned_msg = ControlRequest(command='DEACTIVATE')
                await asyncio.sleep(0.1)
                return returned_msg

        async def start_grpc_server():
            grpc_server = grpc.aio.server(options=[
                ('grpc.max_send_message_length', -1),
                ('grpc.max_receive_message_length', -1),
            ])

            jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server(
                DummyServer(), grpc_server)
            service_names = (
                jina_pb2.DESCRIPTOR.services_by_name['JinaControlRequestRPC'].full_name,
                reflection.SERVICE_NAME,
            )
            reflection.enable_server_reflection(service_names, grpc_server)
            grpc_server.add_insecure_port(f'localhost:{port}')

            await grpc_server.start()
            event.set()
            await grpc_server.wait_for_termination()

        asyncio.run(start_grpc_server())

    port1 = random_port()
    server_process1 = Process(
        target=listen,
        args=(
            port1,
            server1_ready_event,
        ),
    )
    server_process1.start()

    time.sleep(0.1)
    server1_ready_event.wait()

    pool = GrpcConnectionPool()

    pool.add_connection(deployment='encoder',
                        head=False,
                        address=f'localhost:{port1}')
    sent_msg = ControlRequest(command='STATUS')

    results_call_1 = pool.send_request(request=sent_msg,
                                       deployment='encoder',
                                       head=False,
                                       timeout=1.0)

    assert len(results_call_1) == 1
    response1, meta = await results_call_1[0]
    assert response1.command == 'DEACTIVATE'

    results_call_2 = pool.send_request(request=sent_msg,
                                       deployment='encoder',
                                       head=False,
                                       timeout=0.05)
    assert len(results_call_2) == 1
    with pytest.raises(AioRpcError):
        await results_call_2[0]

    await pool.close()
    server_process1.kill()
    server_process1.join()
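
The timeout assertion relies on standard gRPC deadline semantics: the dummy server sleeps for 0.1 s, so a 0.05 s client deadline surfaces as an `AioRpcError`. The same behaviour can be reproduced with plain grpcio and no Jina at all; the address and method path below are placeholders:

import asyncio
import grpc

async def expect_deadline_exceeded():
    # a non-routable address guarantees the call cannot finish in time
    async with grpc.aio.insecure_channel('10.255.255.1:12345') as channel:
        call = channel.unary_unary('/grpc.health.v1.Health/Check')
        try:
            await call(b'', timeout=0.05)
        except grpc.aio.AioRpcError as err:
            print(err.code())  # DEADLINE_EXCEEDED once the 50 ms deadline lapses

asyncio.run(expect_deadline_exceeded())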
Example #4
class HeadRuntime(AsyncNewLoopRuntime, ABC):
    """
    The Runtime used in head Pods. It handles Gateway requests and forwards them to uses_before/uses_after and its workers
    """

    DEFAULT_POLLING = PollingType.ANY

    def __init__(
        self,
        args: argparse.Namespace,
        **kwargs,
    ):
        """Initialize grpc server for the head runtime.
        :param args: args from CLI
        :param kwargs: keyword args
        """
        self._health_servicer = health.HealthServicer(
            experimental_non_blocking=True)

        super().__init__(args, **kwargs)
        if args.name is None:
            args.name = ''
        self.name = args.name
        self._deployment_name = os.getenv('JINA_DEPLOYMENT_NAME', 'worker')
        self.connection_pool = GrpcConnectionPool(
            logger=self.logger,
            compression=args.compression,
            metrics_registry=self.metrics_registry,
        )
        self._retries = self.args.retries

        if self.metrics_registry:
            with ImportExtensions(
                    required=True,
                    help_text=
                    'You need to install `prometheus_client` to use the monitoring functionality of jina',
            ):
                from prometheus_client import Summary

            self._summary = (Summary(
                'receiving_request_seconds',
                'Time spent processing request',
                registry=self.metrics_registry,
                namespace='jina',
                labelnames=('runtime_name', ),
            ).labels(self.args.name).time())
        else:
            self._summary = contextlib.nullcontext()

        polling = getattr(args, 'polling', self.DEFAULT_POLLING.name)
        try:
            # try loading the polling args as json
            endpoint_polling = json.loads(polling)
            # '*' is used as a wildcard and will match all endpoints, except /index, /search and explicitly defined endpoints
            default_polling = (PollingType.from_string(endpoint_polling['*'])
                               if '*' in endpoint_polling else
                               self.DEFAULT_POLLING)
            self._polling = self._default_polling_dict(default_polling)
            for endpoint, value in endpoint_polling.items():
                self._polling[endpoint] = (PollingType(value)
                                           if isinstance(value, int) else
                                           PollingType.from_string(value))
        except (ValueError, TypeError):
            # the polling arg is not valid JSON; try interpreting it as a polling enum type
            default_polling = (polling if isinstance(polling, PollingType)
                               else PollingType.from_string(polling))
            self._polling = self._default_polling_dict(default_polling)

        if hasattr(args, 'connection_list') and args.connection_list:
            connection_list = json.loads(args.connection_list)
            for shard_id in connection_list:
                shard_connections = connection_list[shard_id]
                if isinstance(shard_connections, str):
                    self.connection_pool.add_connection(
                        deployment=self._deployment_name,
                        address=shard_connections,
                        shard_id=int(shard_id),
                    )
                else:
                    for connection in shard_connections:
                        self.connection_pool.add_connection(
                            deployment=self._deployment_name,
                            address=connection,
                            shard_id=int(shard_id),
                        )

        self.uses_before_address = args.uses_before_address
        self.timeout_send = args.timeout_send
        if self.timeout_send:
            self.timeout_send /= 1e3  # convert ms to seconds

        if self.uses_before_address:
            self.connection_pool.add_connection(
                deployment='uses_before', address=self.uses_before_address)
        self.uses_after_address = args.uses_after_address
        if self.uses_after_address:
            self.connection_pool.add_connection(
                deployment='uses_after', address=self.uses_after_address)
        self._reduce = not args.disable_reduce

    def _default_polling_dict(self, default_polling):
        return defaultdict(
            lambda: default_polling,
            {
                '/search': PollingType.ALL,
                '/index': PollingType.ANY
            },
        )

    async def async_setup(self):
        """Wait for the GRPC server to start"""
        self._grpc_server = grpc.aio.server(options=[
            ('grpc.max_send_message_length', -1),
            ('grpc.max_receive_message_length', -1),
        ])

        jina_pb2_grpc.add_JinaSingleDataRequestRPCServicer_to_server(
            self, self._grpc_server)
        jina_pb2_grpc.add_JinaDataRequestRPCServicer_to_server(
            self, self._grpc_server)
        jina_pb2_grpc.add_JinaDiscoverEndpointsRPCServicer_to_server(
            self, self._grpc_server)
        jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(
            self, self._grpc_server)
        service_names = (
            jina_pb2.DESCRIPTOR.services_by_name['JinaSingleDataRequestRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaDataRequestRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaDiscoverEndpointsRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name,
            reflection.SERVICE_NAME,
        )
        # register the health servicer and mark all services as SERVING
        health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer,
                                                     self._grpc_server)

        for service in service_names:
            self._health_servicer.set(service,
                                      health_pb2.HealthCheckResponse.SERVING)
        reflection.enable_server_reflection(service_names, self._grpc_server)

        bind_addr = f'0.0.0.0:{self.args.port}'
        self._grpc_server.add_insecure_port(bind_addr)
        self.logger.debug(f'start listening on {bind_addr}')
        await self._grpc_server.start()

    async def async_run_forever(self):
        """Block until the GRPC server is terminated"""
        self.connection_pool.start()
        await self._grpc_server.wait_for_termination()

    async def async_cancel(self):
        """Stop the GRPC server"""
        self.logger.debug('cancel HeadRuntime')

        await self._grpc_server.stop(0)

    async def async_teardown(self):
        """Close the connection pool"""
        self._health_servicer.enter_graceful_shutdown()
        await self.async_cancel()
        await self.connection_pool.close()

    async def process_single_data(self, request: DataRequest,
                                  context) -> DataRequest:
        """
        Process the received requests and return the result as a new request

        :param request: the data request to process
        :param context: grpc context
        :returns: the response request
        """
        return await self.process_data([request], context)

    def _handle_internalnetworkerror(self, err, context, response):
        err_code = err.code()
        if err_code == grpc.StatusCode.UNAVAILABLE:
            context.set_details(
                f'|Head: Failed to connect to worker (Executor) pod at address {err.dest_addr}. It may be down.'
            )
        elif err_code == grpc.StatusCode.DEADLINE_EXCEEDED:
            context.set_details(
                f'|Head: Connection to worker (Executor) pod at address {err.dest_addr} could be established, but timed out.'
            )
        context.set_code(err.code())
        self.logger.error(
            f'Error while getting responses from Pods: {err.details()}')
        if err.request_id:
            response.header.request_id = err.request_id
        return response

    async def process_data(self, requests: List[DataRequest],
                           context) -> DataRequest:
        """
        Process the received data request and return the result as a new request

        :param requests: the data requests to process
        :param context: grpc context
        :returns: the response request
        """
        try:
            with self._summary:
                endpoint = dict(context.invocation_metadata()).get('endpoint')
                response, metadata = await self._handle_data_request(
                    requests, endpoint)
                context.set_trailing_metadata(metadata.items())
                return response
        except InternalNetworkError as err:  # can't connect, Flow broken, interrupt the streaming through gRPC error mechanism
            return self._handle_internalnetworkerror(err=err,
                                                     context=context,
                                                     response=Response())
        except (
                RuntimeError,
                Exception,
        ) as ex:  # some other error, keep streaming going, just add error info
            self.logger.error(
                f'{ex!r}' +
                (f'\n add "--quiet-error" to suppress the exception details'
                 if not self.args.quiet_error else ''),
                exc_info=not self.args.quiet_error,
            )
            requests[0].add_exception(ex, executor=None)
            context.set_trailing_metadata((('is-error', 'true'), ))
            return requests[0]

    async def endpoint_discovery(self, empty,
                                 context) -> jina_pb2.EndpointsProto:
        """
        Uses the connection pool to send a discover endpoint call to the workers

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        response = jina_pb2.EndpointsProto()
        try:
            if self.uses_before_address:
                (
                    uses_before_response,
                    _,
                ) = await self.connection_pool.send_discover_endpoint(
                    deployment='uses_before', head=False)
                response.endpoints.extend(uses_before_response.endpoints)
            if self.uses_after_address:
                (
                    uses_after_response,
                    _,
                ) = await self.connection_pool.send_discover_endpoint(
                    deployment='uses_after', head=False)
                response.endpoints.extend(uses_after_response.endpoints)

            worker_response, _ = await self.connection_pool.send_discover_endpoint(
                deployment=self._deployment_name, head=False)
            response.endpoints.extend(worker_response.endpoints)
        except InternalNetworkError as err:  # can't connect, Flow broken, interrupt the streaming through gRPC error mechanism
            return self._handle_internalnetworkerror(err=err,
                                                     context=context,
                                                     response=response)

        return response

    async def _handle_data_request(
            self, requests: List[DataRequest],
            endpoint: Optional[str]) -> Tuple[DataRequest, Dict]:
        self.logger.debug(f'recv {len(requests)} DataRequest(s)')

        DataRequestHandler.merge_routes(requests)

        uses_before_metadata = None
        if self.uses_before_address:
            (
                response,
                uses_before_metadata,
            ) = await self.connection_pool.send_requests_once(
                requests,
                deployment='uses_before',
                timeout=self.timeout_send,
                retries=self._retries,
            )
            requests = [response]

        worker_send_tasks = self.connection_pool.send_requests(
            requests=requests,
            deployment=self._deployment_name,
            polling_type=self._polling[endpoint],
            timeout=self.timeout_send,
            retries=self._retries,
        )

        worker_results = await asyncio.gather(*worker_send_tasks)

        if len(worker_results) == 0:
            raise RuntimeError(
                f'Head {self.name} did not receive a response when sending messages to the worker pods'
            )

        worker_results, metadata = zip(*worker_results)

        response_request = worker_results[0]
        uses_after_metadata = None
        if self.uses_after_address:
            (
                response_request,
                uses_after_metadata,
            ) = await self.connection_pool.send_requests_once(
                worker_results,
                deployment='uses_after',
                timeout=self.timeout_send,
                retries=self._retries,
            )
        elif len(worker_results) > 1 and self._reduce:
            DataRequestHandler.reduce_requests(worker_results)
        elif len(worker_results) > 1 and not self._reduce:
            # workers returned multiple responses, but the head is configured to skip reduction
            # just concatenate the docs in this case
            response_request.data.docs = DataRequestHandler.get_docs_from_request(
                requests, field='docs')

        merged_metadata = self._merge_metadata(metadata, uses_after_metadata,
                                               uses_before_metadata)

        return response_request, merged_metadata

    def _merge_metadata(self, metadata, uses_after_metadata,
                        uses_before_metadata):
        merged_metadata = {}
        if uses_before_metadata:
            for key, value in uses_before_metadata:
                merged_metadata[key] = value
        for meta in metadata:
            for key, value in meta:
                merged_metadata[key] = value
        if uses_after_metadata:
            for key, value in uses_after_metadata:
                merged_metadata[key] = value
        return merged_metadata

    async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
        """
        Process the call and return the JinaInfoProto of the Runtime

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        infoProto = jina_pb2.JinaInfoProto()
        version, env_info = get_full_version()
        for k, v in version.items():
            infoProto.jina[k] = str(v)
        for k, v in env_info.items():
            infoProto.envs[k] = str(v)
        return infoProto
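
The polling setup above accepts either a bare enum name ('ANY'/'ALL') or a JSON map from endpoint to polling type, with '*' as the default for unlisted endpoints and /index and /search pre-seeded. Below is a standalone re-implementation with a stand-in enum so it runs without Jina installed; it is a sketch of the resolution order, not Jina's code:

import json
from collections import defaultdict
from enum import Enum

class PollingType(Enum):  # stand-in for jina's PollingType
    ANY = 0
    ALL = 1

def resolve_polling(polling, default=PollingType.ANY):
    base = {'/search': PollingType.ALL, '/index': PollingType.ANY}
    try:
        endpoint_polling = json.loads(polling)
        if not isinstance(endpoint_polling, dict):
            raise TypeError(polling)
        default = PollingType[endpoint_polling.pop('*', default.name)]
        table = defaultdict(lambda: default, base)
        for endpoint, value in endpoint_polling.items():
            # values may be raw ints (1) or enum names ('ALL'), as above
            table[endpoint] = (PollingType(value) if isinstance(value, int)
                               else PollingType[value])
        return table
    except (ValueError, TypeError):
        # not JSON: the whole string names a single polling type
        return defaultdict(lambda: PollingType[polling], base)

table = resolve_polling('{"*": "ANY", "/custom": "ALL"}')
assert table['/custom'] is PollingType.ALL
assert table['/search'] is PollingType.ALL  # pre-seeded, not overridden here
assert table['/anything-else'] is PollingType.ANY  # falls back to '*'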
Example #5
async def test_grpc_connection_pool_real_sending():
    server1_ready_event = multiprocessing.Event()
    server2_ready_event = multiprocessing.Event()

    def listen(port, event: multiprocessing.Event):
        class DummyServer:
            async def process_control(self, request, *args):
                returned_msg = ControlRequest(command='DEACTIVATE')
                return returned_msg

        async def start_grpc_server():
            grpc_server = grpc.aio.server(
                options=[
                    ('grpc.max_send_message_length', -1),
                    ('grpc.max_receive_message_length', -1),
                ]
            )

            jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server(
                DummyServer(), grpc_server
            )
            grpc_server.add_insecure_port(f'localhost:{port}')

            await grpc_server.start()
            event.set()
            await grpc_server.wait_for_termination()

        asyncio.run(start_grpc_server())

    port1 = random_port()
    server_process1 = Process(
        target=listen,
        args=(
            port1,
            server1_ready_event,
        ),
    )
    server_process1.start()

    port2 = random_port()
    server_process2 = Process(
        target=listen,
        args=(
            port2,
            server2_ready_event,
        ),
    )
    server_process2.start()

    time.sleep(0.1)
    server1_ready_event.wait()
    server2_ready_event.wait()

    pool = GrpcConnectionPool()

    pool.add_connection(deployment='encoder', head=False, address=f'localhost:{port1}')
    pool.add_connection(deployment='encoder', head=False, address=f'localhost:{port2}')
    sent_msg = ControlRequest(command='STATUS')

    results_call_1 = pool.send_request(
        request=sent_msg, deployment='encoder', head=False
    )
    results_call_2 = pool.send_request(
        request=sent_msg, deployment='encoder', head=False
    )
    assert len(results_call_1) == 1
    assert len(results_call_2) == 1

    response1, meta = await results_call_1[0]
    assert response1.command == 'DEACTIVATE'

    response2, meta = await results_call_2[0]
    assert response2.command == 'DEACTIVATE'

    await pool.close()
    server_process1.kill()
    server_process2.kill()
    server_process1.join()
    server_process2.join()
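
Once the readiness event is set from inside the server coroutine, the extra `time.sleep(0.1)` before the waits is redundant. Here is the handshake in isolation, with a placeholder port; a sketch of the pattern, not the project's test utilities:

import asyncio
import multiprocessing

import grpc

def serve(port, ready):
    async def _run():
        server = grpc.aio.server()
        server.add_insecure_port(f'localhost:{port}')
        await server.start()
        ready.set()  # the parent may connect from this point on
        await server.wait_for_termination()

    asyncio.run(_run())

if __name__ == '__main__':
    ready = multiprocessing.Event()
    proc = multiprocessing.Process(target=serve, args=(23456, ready))
    proc.start()
    ready.wait()  # no fixed sleep needed
    proc.terminate()
    proc.join()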