Example #1
    async def _handle_data_request(
            self, requests: List[DataRequest],
            endpoint: Optional[str]) -> Tuple[DataRequest, Dict]:
        self.logger.debug(f'recv {len(requests)} DataRequest(s)')

        DataRequestHandler.merge_routes(requests)

        uses_before_metadata = None
        if self.uses_before_address:
            (
                response,
                uses_before_metadata,
            ) = await self.connection_pool.send_requests_once(
                requests, deployment='uses_before', timeout=self.timeout_send)
            requests = [response]

        worker_send_tasks = self.connection_pool.send_requests(
            requests=requests,
            deployment=self._deployment_name,
            polling_type=self._polling[endpoint],
            timeout=self.timeout_send,
        )

        worker_results = await asyncio.gather(*worker_send_tasks)

        if len(worker_results) == 0:
            raise RuntimeError(
                f'Head {self.name} did not receive a response when sending message to worker pods'
            )

        worker_results, metadata = zip(*worker_results)

        response_request = worker_results[0]
        uses_after_metadata = None
        if self.uses_after_address:
            (
                response_request,
                uses_after_metadata,
            ) = await self.connection_pool.send_requests_once(
                worker_results,
                deployment='uses_after',
                timeout=self.timeout_send)
        elif len(worker_results) > 1 and self._reduce:
            DataRequestHandler.reduce_requests(worker_results)
        elif len(worker_results) > 1 and not self._reduce:
            # worker returned multiple responses, but the head is configured to skip reduction
            # just concatenate the docs in this case
            response_request.data.docs = DataRequestHandler.get_docs_from_request(
                requests, field='docs')

        merged_metadata = self._merge_metadata(metadata, uses_after_metadata,
                                               uses_before_metadata)

        return response_request, merged_metadata
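
The zip(*worker_results) call above transposes a list of (request, metadata) pairs into two parallel tuples. A minimal, self-contained illustration of the idiom, with plain strings and dicts standing in for DataRequest objects and gRPC metadata:

# toy stand-ins for the (DataRequest, metadata) pairs returned by each worker
worker_results = [('req_a', {'worker': '0'}), ('req_b', {'worker': '1'})]

# one zip(*...) call splits the pairs into two parallel tuples
responses, metadata = zip(*worker_results)

assert responses == ('req_a', 'req_b')
assert metadata == ({'worker': '0'}, {'worker': '1'})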
Example #2
    async def _handle_data_request(
            self, requests: List[DataRequest],
            endpoint: Optional[str]) -> Tuple[DataRequest, Dict]:
        self.logger.debug(f'recv {len(requests)} DataRequest(s)')

        DataRequestHandler.merge_routes(requests)

        uses_before_metadata = None
        if self.uses_before_address:
            (
                response,
                uses_before_metadata,
            ) = await self.connection_pool.send_requests_once(
                requests, deployment='uses_before')
            requests = [response]
        elif len(requests) > 1 and not self._has_uses:
            requests = [DataRequestHandler.reduce_requests(requests)]

        worker_send_tasks = self.connection_pool.send_requests(
            requests=requests,
            deployment=self._deployment_name,
            polling_type=self._polling[endpoint],
        )

        worker_results = await asyncio.gather(*worker_send_tasks)

        if len(worker_results) == 0:
            raise RuntimeError(
                f'Head {self.name} did not receive a response when sending message to worker pods'
            )

        worker_results, metadata = zip(*worker_results)

        response_request = worker_results[0]
        uses_after_metadata = None
        if self.uses_after_address:
            (
                response_request,
                uses_after_metadata,
            ) = await self.connection_pool.send_requests_once(
                worker_results, deployment='uses_after')
        elif len(worker_results) > 1:
            DataRequestHandler.reduce_requests(worker_results)

        merged_metadata = self._merge_metadata(metadata, uses_after_metadata,
                                               uses_before_metadata)

        return response_request, merged_metadata
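
Both variants fan requests out through connection_pool.send_requests(...) and collect the responses with asyncio.gather, which returns results in task order rather than completion order. A minimal sketch of that fan-out/gather pattern, with a stub coroutine standing in for the connection pool:

import asyncio

async def fake_send(shard_id):
    # stand-in for one send task created by the connection pool;
    # higher shard ids finish sooner here
    await asyncio.sleep(0.03 - 0.01 * shard_id)
    return f'response-{shard_id}', {'shard': str(shard_id)}

async def main():
    tasks = [asyncio.ensure_future(fake_send(i)) for i in range(3)]
    results = await asyncio.gather(*tasks)
    # results follow task order even though shard 2 completed first
    assert [r for r, _ in results] == ['response-0', 'response-1', 'response-2']

asyncio.run(main())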
Example #3
    def __init__(
        self,
        args: argparse.Namespace,
        cancel_event: Optional[Union['asyncio.Event', 'multiprocessing.Event',
                                     'threading.Event']] = None,
        **kwargs,
    ):
        """Initialize grpc and data request handling.
        :param args: args from CLI
        :param cancel_event: the cancel event used to wait for canceling
        :param kwargs: keyword args
        """
        super().__init__(args, cancel_event, **kwargs)

        # Keep this initialization order, otherwise readiness check is not valid
        self._data_request_handler = DataRequestHandler(args, self.logger)
Example #4
    async def _async_setup_grpc_server(self):
        """
        Start the DataRequestHandler and wait for the GRPC server to start
        """

        # Keep this initialization order
        # otherwise readiness check is not valid
        # The DataRequestHandler needs to be started BEFORE the grpc server
        self._data_request_handler = DataRequestHandler(
            self.args, self.logger, self.metrics_registry
        )

        self._grpc_server = grpc.aio.server(
            options=[
                ('grpc.max_send_message_length', -1),
                ('grpc.max_receive_message_length', -1),
            ]
        )

        jina_pb2_grpc.add_JinaSingleDataRequestRPCServicer_to_server(
            self, self._grpc_server
        )
        jina_pb2_grpc.add_JinaDataRequestRPCServicer_to_server(self, self._grpc_server)

        jina_pb2_grpc.add_JinaDiscoverEndpointsRPCServicer_to_server(
            self, self._grpc_server
        )
        jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(self, self._grpc_server)
        service_names = (
            jina_pb2.DESCRIPTOR.services_by_name['JinaSingleDataRequestRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaDataRequestRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaDiscoverEndpointsRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name,
            reflection.SERVICE_NAME,
        )
        # Mark all services as healthy.
        health_pb2_grpc.add_HealthServicer_to_server(
            self._health_servicer, self._grpc_server
        )

        for service in service_names:
            self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING)
        reflection.enable_server_reflection(service_names, self._grpc_server)
        bind_addr = f'0.0.0.0:{self.args.port}'
        self.logger.debug(f'start listening on {bind_addr}')
        self._grpc_server.add_insecure_port(bind_addr)
        await self._grpc_server.start()
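
Since every service name is registered with the standard gRPC health servicer, readiness can be probed with the stock health-checking client. A hedged sketch, assuming grpcio-health-checking is installed, the server listens on localhost:12345, and 'jina.JinaDataRequestRPC' matches the full service name taken from jina_pb2.DESCRIPTOR above:

import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc

# assumed address; self.args.port determines the real one
channel = grpc.insecure_channel('localhost:12345')
stub = health_pb2_grpc.HealthStub(channel)

# 'jina.JinaDataRequestRPC' is an assumed full service name
resp = stub.Check(
    health_pb2.HealthCheckRequest(service='jina.JinaDataRequestRPC'))
print(health_pb2.HealthCheckResponse.ServingStatus.Name(resp.status))  # SERVING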
Example #5
        async def _wait_previous_and_send(
            self,
            request: DataRequest,
            previous_task: Optional[asyncio.Task],
            connection_pool: GrpcConnectionPool,
            endpoint: Optional[str],
            executor_endpoint_mapping: Optional[Dict] = None,
            target_executor_pattern: Optional[str] = None,
        ):
            # wait for any upstream task, then apply this node's conditions before sending
            metadata = {}
            if previous_task is not None:
                result = await previous_task
                request, metadata = result[0], result[1]
            if metadata and 'is-error' in metadata:
                return request, metadata
            elif request is not None:
                self.parts_to_send.append(request)
                # a node only sends once all expected request parts have arrived
                if len(self.parts_to_send) == self.number_of_parts:
                    self.start_time = datetime.utcnow()
                    if self._filter_condition is not None:
                        self._update_requests()
                    if self._reduce and len(self.parts_to_send) > 1:
                        self.parts_to_send = [
                            DataRequestHandler.reduce_requests(
                                self.parts_to_send)
                        ]

                    # avoid sending to executor which does not bind to this endpoint
                    if endpoint is not None and executor_endpoint_mapping is not None:
                        if (endpoint
                                not in executor_endpoint_mapping[self.name]
                                and __default_endpoint__
                                not in executor_endpoint_mapping[self.name]):
                            return request, metadata

                    if target_executor_pattern is not None and not re.match(
                            target_executor_pattern, self.name):
                        return request, metadata
                    # otherwise, send to executor and get response
                    try:
                        resp, metadata = await connection_pool.send_requests_once(
                            requests=self.parts_to_send,
                            deployment=self.name,
                            head=True,
                            endpoint=endpoint,
                            timeout=self._timeout_send,
                            retries=self._retries,
                        )
                    except InternalNetworkError as err:
                        self._handle_internalnetworkerror(err)

                    self.end_time = datetime.utcnow()
                    if metadata and 'is-error' in metadata:
                        self.status = resp.header.status
                    return resp, metadata

            return None, {}
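
The previous_task argument is what chains graph nodes together: each node awaits its upstream task first, short-circuits when the metadata carries an 'is-error' marker, and only then performs its own send. A stripped-down sketch of that chaining, with a trivial coroutine in place of the connection pool:

import asyncio

async def wait_previous_and_send(name, previous_task):
    request, metadata = 'initial', {}
    if previous_task is not None:
        # wait for the upstream node to finish before sending
        request, metadata = await previous_task
    if metadata and 'is-error' in metadata:
        return request, metadata  # propagate the error without sending
    # stand-in for connection_pool.send_requests_once(...)
    return f'{request}->{name}', metadata

async def main():
    task_a = asyncio.ensure_future(wait_previous_and_send('a', None))
    task_b = asyncio.ensure_future(wait_previous_and_send('b', task_a))
    assert await task_b == ('initial->a->b', {})

asyncio.run(main())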
Example #6
async def test_async_data_request_handler_new_docs(logger):
    args = set_pod_parser().parse_args(['--uses', 'AsyncNewDocsExecutor'])
    handler = DataRequestHandler(args, logger)
    req = list(
        request_generator(
            '/',
            DocumentArray([Document(text='input document')
                           for _ in range(10)])))[0]
    assert len(req.docs) == 10
    response = await handler.handle(requests=[req])

    assert len(response.docs) == 1
    assert response.docs[0].text == 'new document'
Example #7
async def test_data_request_handler_change_docs_from_partial_requests(logger):
    NUM_PARTIAL_REQUESTS = 5
    args = set_pod_parser().parse_args(['--uses', 'MergeChangeDocsExecutor'])
    handler = DataRequestHandler(args, logger)

    partial_reqs = [
        list(
            request_generator(
                '/',
                DocumentArray(
                    [Document(text='input document') for _ in range(10)])))[0]
    ] * NUM_PARTIAL_REQUESTS
    assert len(partial_reqs) == 5
    assert len(partial_reqs[0].docs) == 10
    response = await handler.handle(requests=partial_reqs)

    assert len(response.docs) == 10 * NUM_PARTIAL_REQUESTS
    for doc in response.docs:
        assert doc.text == 'changed document'
Example #8
class WorkerRuntime(AsyncNewLoopRuntime, ABC):
    """Runtime procedure leveraging :class:`Grpclet` for sending DataRequests"""
    def __init__(
        self,
        args: argparse.Namespace,
        cancel_event: Optional[Union['asyncio.Event', 'multiprocessing.Event',
                                     'threading.Event']] = None,
        **kwargs,
    ):
        """Initialize grpc and data request handling.
        :param args: args from CLI
        :param cancel_event: the cancel event used to wait for canceling
        :param kwargs: keyword args
        """
        super().__init__(args, cancel_event, **kwargs)

        # Keep this initialization order, otherwise readiness check is not valid
        self._data_request_handler = DataRequestHandler(args, self.logger)

    async def async_setup(self):
        """
        Wait for the GRPC server to start
        """
        self._grpc_server = grpc.aio.server(options=[
            ('grpc.max_send_message_length', -1),
            ('grpc.max_receive_message_length', -1),
        ])

        jina_pb2_grpc.add_JinaSingleDataRequestRPCServicer_to_server(
            self, self._grpc_server)
        jina_pb2_grpc.add_JinaDataRequestRPCServicer_to_server(
            self, self._grpc_server)
        jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server(
            self, self._grpc_server)
        bind_addr = f'0.0.0.0:{self.args.port_in}'
        self.logger.debug(f'Start listening on {bind_addr}')
        self._grpc_server.add_insecure_port(bind_addr)
        await self._grpc_server.start()

    async def async_run_forever(self):
        """Block until the GRPC server is terminated """
        await self._grpc_server.wait_for_termination()

    async def async_cancel(self):
        """Stop the GRPC server"""
        self.logger.debug('Cancel WorkerRuntime')

        # 0.5 gives the runtime some time to complete outstanding responses
        # this should be handled better, 0.5 is a rather random number
        await self._grpc_server.stop(0.5)
        self.logger.debug('Stopped GRPC Server')

    async def async_teardown(self):
        """Close the data request handler"""
        await self.async_cancel()
        self._data_request_handler.close()

    async def process_single_data(self, request: DataRequest,
                                  context) -> DataRequest:
        """
        Process the received request and return the result as a new request

        :param request: the data request to process
        :param context: grpc context
        :returns: the response request
        """
        return await self.process_data([request], context)

    async def process_data(self, requests: List[DataRequest],
                           context) -> DataRequest:
        """
        Process the received requests and return the result as a new request

        :param requests: the data requests to process
        :param context: grpc context
        :returns: the response request
        """
        try:
            if self.logger.debug_enabled:
                self._log_data_request(requests[0])

            return await self._data_request_handler.handle(requests=requests)
        except Exception as ex:
            self.logger.error(
                f'{ex!r}' +
                ('\n add "--quiet-error" to suppress the exception details'
                 if not self.args.quiet_error else ''),
                exc_info=not self.args.quiet_error,
            )

            requests[0].add_exception(ex, self._data_request_handler._executor)
            context.set_trailing_metadata((('is-error', 'true'), ))
            return requests[0]

    async def process_control(self, request: ControlRequest,
                              *args) -> ControlRequest:
        """
        Process the received control request and return the same request

        :param request: the control request to process
        :param args: additional arguments in the grpc call, ignored
        :returns: the input request
        """
        try:
            if self.logger.debug_enabled:
                self._log_control_request(request)

            if request.command == 'STATUS':
                pass
            else:
                raise RuntimeError(
                    f'WorkerRuntime received unsupported ControlRequest command {request.command}'
                )
        except Exception as ex:
            self.logger.error(
                f'{ex!r}' +
                ('\n add "--quiet-error" to suppress the exception details'
                 if not self.args.quiet_error else ''),
                exc_info=not self.args.quiet_error,
            )

            request.add_exception(ex, self._data_request_handler._executor)
        return request
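
Note that process_data reports executor failures through trailing metadata instead of a gRPC status code, so callers must inspect the trailing metadata themselves. A hedged client-side sketch using the standard with_call API of generated stubs; the method name process_data on the stub is a placeholder:

def send_and_check(stub, request):
    # with_call returns the response together with the call object,
    # which exposes the trailing metadata set by the server
    response, call = stub.process_data.with_call(request)
    is_error = any(key == 'is-error' for key, _ in call.trailing_metadata())
    return response, is_error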
Example #9
class WorkerRuntime(AsyncNewLoopRuntime, ABC):
    """Runtime procedure leveraging :class:`Grpclet` for sending DataRequests"""

    def __init__(
        self,
        args: argparse.Namespace,
        **kwargs,
    ):
        """Initialize grpc and data request handling.
        :param args: args from CLI
        :param kwargs: keyword args
        """
        self._health_servicer = health.HealthServicer(experimental_non_blocking=True)
        super().__init__(args, **kwargs)

    async def async_setup(self):
        """
        Start the DataRequestHandler and wait for the GRPC and Monitoring servers to start
        """
        if self.metrics_registry:
            with ImportExtensions(
                required=True,
                help_text='You need to install the `prometheus_client` to use the monitoring functionality of jina',
            ):
                from prometheus_client import Summary

            self._summary_time = (
                Summary(
                    'receiving_request_seconds',
                    'Time spent processing request',
                    registry=self.metrics_registry,
                    namespace='jina',
                    labelnames=('runtime_name',),
                )
                .labels(self.args.name)
                .time()
            )
        else:
            self._summary_time = contextlib.nullcontext()

        await self._async_setup_grpc_server()

    async def _async_setup_grpc_server(self):
        """
        Start the DataRequestHandler and wait for the GRPC server to start
        """

        # Keep this initialization order
        # otherwise readiness check is not valid
        # The DataRequestHandler needs to be started BEFORE the grpc server
        self._data_request_handler = DataRequestHandler(
            self.args, self.logger, self.metrics_registry
        )

        self._grpc_server = grpc.aio.server(
            options=[
                ('grpc.max_send_message_length', -1),
                ('grpc.max_receive_message_length', -1),
            ]
        )

        jina_pb2_grpc.add_JinaSingleDataRequestRPCServicer_to_server(
            self, self._grpc_server
        )
        jina_pb2_grpc.add_JinaDataRequestRPCServicer_to_server(self, self._grpc_server)

        jina_pb2_grpc.add_JinaDiscoverEndpointsRPCServicer_to_server(
            self, self._grpc_server
        )
        jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(self, self._grpc_server)
        service_names = (
            jina_pb2.DESCRIPTOR.services_by_name['JinaSingleDataRequestRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaDataRequestRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaDiscoverEndpointsRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name,
            reflection.SERVICE_NAME,
        )
        # Mark all services as healthy.
        health_pb2_grpc.add_HealthServicer_to_server(
            self._health_servicer, self._grpc_server
        )

        for service in service_names:
            self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING)
        reflection.enable_server_reflection(service_names, self._grpc_server)
        bind_addr = f'0.0.0.0:{self.args.port}'
        self.logger.debug(f'start listening on {bind_addr}')
        self._grpc_server.add_insecure_port(bind_addr)
        await self._grpc_server.start()

    async def async_run_forever(self):
        """Block until the GRPC server is terminated"""
        await self._grpc_server.wait_for_termination()

    async def async_cancel(self):
        """Stop the GRPC server"""
        self.logger.debug('cancel WorkerRuntime')

        # 1.0 gives the runtime some time to complete outstanding responses
        # this should be handled better, 1.0 is a rather random number
        await self._grpc_server.stop(1.0)
        self.logger.debug('stopped GRPC Server')

    async def async_teardown(self):
        """Close the data request handler"""
        self._health_servicer.enter_graceful_shutdown()
        await self.async_cancel()
        self._data_request_handler.close()

    async def process_single_data(self, request: DataRequest, context) -> DataRequest:
        """
        Process the received request and return the result as a new request

        :param request: the data request to process
        :param context: grpc context
        :returns: the response request
        """
        return await self.process_data([request], context)

    async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:
        """
        Process the call requested and return the list of Endpoints exposed by the Executor wrapped inside this Runtime

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        endpointsProto = jina_pb2.EndpointsProto()
        endpointsProto.endpoints.extend(
            list(self._data_request_handler._executor.requests.keys())
        )
        return endpointsProto

    async def process_data(self, requests: List[DataRequest], context) -> DataRequest:
        """
        Process the received requests and return the result as a new request

        :param requests: the data requests to process
        :param context: grpc context
        :returns: the response request
        """

        with self._summary_time:
            try:
                if self.logger.debug_enabled:
                    self._log_data_request(requests[0])

                return await self._data_request_handler.handle(requests=requests)
            except Exception as ex:
                self.logger.error(
                    f'{ex!r}'
                    + (
                        '\n add "--quiet-error" to suppress the exception details'
                        if not self.args.quiet_error
                        else ''
                    ),
                    exc_info=not self.args.quiet_error,
                )

                requests[0].add_exception(ex, self._data_request_handler._executor)
                context.set_trailing_metadata((('is-error', 'true'),))
                return requests[0]

    async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
        """
        Process the call requested and return the JinaInfo of the Runtime

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        infoProto = jina_pb2.JinaInfoProto()
        version, env_info = get_full_version()
        for k, v in version.items():
            infoProto.jina[k] = str(v)
        for k, v in env_info.items():
            infoProto.envs[k] = str(v)
        return infoProto
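
The Summary(...).labels(...).time() object built in async_setup is reused as a context manager around every process_data call. A small, self-contained sketch of that prometheus_client pattern, using a throwaway registry and a sleep as a stand-in for request handling:

import time
from prometheus_client import CollectorRegistry, Summary

registry = CollectorRegistry()
summary = Summary(
    'receiving_request_seconds',
    'Time spent processing request',
    registry=registry,
    namespace='jina',
    labelnames=('runtime_name',),
)

timer = summary.labels('my_runtime').time()  # reusable context manager
with timer:
    time.sleep(0.01)  # stand-in for handling one request

# one observation was recorded under the namespaced metric name
print(registry.get_sample_value(
    'jina_receiving_request_seconds_count', {'runtime_name': 'my_runtime'}))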