def test_grpc_data_runtime(mocker):
    """Run a GRPCDataRuntime in a thread, send it one data request, then cancel it.

    The data request handler is replaced by a mock so the test can check that the
    request reached it; after ``cancel`` the runtime must no longer report ready.
    """
    args = set_pea_parser().parse_args([])
    handle_mock = mocker.Mock()
    address = f'{args.host}:{args.port_in}'

    def start_runtime(runtime_args, handler):
        with GRPCDataRuntime(runtime_args) as runtime:
            runtime._data_request_handler.handle = handler
            runtime.run_forever()

    runtime_thread = Thread(target=start_runtime, args=(args, handle_mock))
    runtime_thread.start()

    is_up = GRPCDataRuntime.wait_for_ready_or_shutdown(
        timeout=5.0, ctrl_address=address, shutdown_event=Event()
    )
    assert is_up

    stub = Grpclet._create_grpc_stub(address, is_async=False)
    stub.Call(_create_test_data_message())
    # give the runtime a moment to dispatch the request to the handler
    time.sleep(0.1)
    handle_mock.assert_called()

    GRPCDataRuntime.cancel(address)
    runtime_thread.join()

    assert not GRPCDataRuntime.is_ready(address)
async def test_send_static_ctrl_msg(mocker):
    """Send a STATUS control message via the static ``Grpclet.send_ctrl_msg`` and
    check that the running Grpclet's message callback is invoked.

    The send is retried every 0.1s while it raises ``RpcError`` (e.g. while the
    server is still starting up).
    """
    # AsyncMock does not seem to exist in python 3.7, this is a manual workaround
    receive_cb = mocker.Mock()

    async def mock_wrapper(msg):
        receive_cb()

    args = set_pea_parser().parse_args([])
    grpclet = Grpclet(args=args, message_callback=mock_wrapper)
    asyncio.get_event_loop().create_task(grpclet.start())

    receive_cb.assert_not_called()

    def send_status():
        # blocking call, executed in the default executor below
        return Grpclet.send_ctrl_msg(
            pod_address=f'{args.host}:{args.port_in}', command='STATUS'
        )

    while True:
        try:
            await asyncio.get_event_loop().run_in_executor(None, send_status)
        except RpcError:
            await asyncio.sleep(0.1)
        else:
            break

    receive_cb.assert_called()
    await grpclet.close(None)
def test_grpc_data_runtime(mocker):
    """Run a GRPCDataRuntime in a separate process, send one data request and cancel it.

    The handler sets a multiprocessing event instead of using a mock, because a
    mock's call record would not cross the process boundary.
    """
    args = set_pea_parser().parse_args([])
    address = f'{args.host}:{args.port_in}'
    handle_mock = multiprocessing.Event()
    cancel_event = multiprocessing.Event()

    def start_runtime(runtime_args, handled_event, stop_event):
        with GRPCDataRuntime(runtime_args, stop_event) as runtime:
            runtime._data_request_handler.handle = (
                lambda *args, **kwargs: handled_event.set()
            )
            runtime.run_forever()

    runtime_process = Process(
        target=start_runtime,
        args=(args, handle_mock, cancel_event),
        daemon=True,
    )
    runtime_process.start()

    is_up = GRPCDataRuntime.wait_for_ready_or_shutdown(
        timeout=5.0, ctrl_address=address, shutdown_event=Event()
    )
    assert is_up

    stub = Grpclet._create_grpc_stub(address, is_async=False)
    stub.Call(_create_test_data_message())
    time.sleep(0.1)
    assert handle_mock.is_set()

    GRPCDataRuntime.cancel(cancel_event)
    runtime_process.join()

    assert not GRPCDataRuntime.is_ready(address)
def test_grpc_data_runtime_waits_for_pending_messages_shutdown(close_method):
    """Stop a runtime while slow requests are in flight and check they all complete.

    ``close_method == 'TERMINATE'`` stops the runtime process with ``terminate()``;
    any other value uses ``GRPCDataRuntime.cancel(cancel_event)``. Afterwards every
    request must have been forwarded and the data request handler must be closed.
    """
    args = set_pea_parser().parse_args([])
    address = f'{args.host}:{args.port_in}'
    cancel_event = multiprocessing.Event()
    handler_closed_event = multiprocessing.Event()
    slow_executor_block_time = 1.0
    pending_requests = 3
    sent_queue = multiprocessing.Queue()

    def start_runtime(runtime_args, stop_event, forwarded, closed_event):
        with GRPCDataRuntime(runtime_args, stop_event) as runtime:
            # every request keeps the handler busy for `slow_executor_block_time` seconds
            runtime._data_request_handler.handle = lambda *args, **kwargs: time.sleep(
                slow_executor_block_time
            )
            runtime._data_request_handler.close = (
                lambda *args, **kwargs: closed_event.set()
            )

            async def record_forwarded(msg):
                forwarded.put('')

            runtime._grpclet.send_message = record_forwarded
            runtime.run_forever()

    runtime_process = Process(
        target=start_runtime,
        args=(args, cancel_event, sent_queue, handler_closed_event),
        daemon=True,
    )
    runtime_process.start()

    assert GRPCDataRuntime.wait_for_ready_or_shutdown(
        timeout=5.0, ctrl_address=address, shutdown_event=Event()
    )

    request_start_time = time.time()
    for _ in range(pending_requests):
        Grpclet._create_grpc_stub(address, is_async=False).Call(
            _create_test_data_message()
        )
    time.sleep(0.1)

    if close_method != 'TERMINATE':
        GRPCDataRuntime.cancel(cancel_event)
    else:
        runtime_process.terminate()
    assert not handler_closed_event.is_set()
    runtime_process.join()

    elapsed = time.time() - request_start_time
    assert elapsed >= slow_executor_block_time * pending_requests
    assert sent_queue.qsize() == pending_requests
    assert handler_closed_event.is_set()
    assert not GRPCDataRuntime.is_ready(address)
def cancel(
    control_address: str,
    **kwargs,
):
    """
    Stop the runtime reachable at ``control_address`` by sending it a TERMINATE control message

    :param control_address: the address where the control message needs to be sent
    :param kwargs: extra keyword arguments
    """
    command = 'TERMINATE'
    try:
        Grpclet.send_ctrl_msg(control_address, command)
    except RpcError:
        # the runtime may shut down before it returns a reply to TERMINATE; that is expected
        return
async def test_send_receive(mocker):
    """Sending a message with ``Grpclet.send_message`` triggers the Grpclet's message callback."""
    # AsyncMock does not seem to exist in python 3.7, this is a manual workaround
    receive_cb = mocker.Mock()

    async def mock_wrapper(msg):
        receive_cb()

    args = set_pea_parser().parse_args([])
    grpclet = Grpclet(args=args, message_callback=mock_wrapper)
    loop = asyncio.get_event_loop()
    loop.create_task(grpclet.start())

    receive_cb.assert_not_called()

    outgoing = _create_msg(args)
    await grpclet.send_message(outgoing)
    # let the server side process the message
    await asyncio.sleep(0.1)

    receive_cb.assert_called()
    await grpclet.close(None)
def __init__(self, args: argparse.Namespace, **kwargs):
    """Set up runtime state, the data request handler and the Grpclet.

    :param args: args from CLI
    :param kwargs: extra keyword arguments
    """
    super().__init__(args, **kwargs)
    self._id = random_identity()
    self._loop = get_or_reuse_loop()
    self._last_active_time = time.perf_counter()

    # per-request and task bookkeeping
    self._partial_requests = None
    self._pending_msgs = defaultdict(list)  # type: Dict[str, List[Message]]
    self._pending_tasks = []
    self._static_routing_table = args.static_routing_table

    self._data_request_handler = DataRequestHandler(args, self.logger)
    # incoming messages are dispatched to `self._callback`
    self._grpclet = Grpclet(
        message_callback=self._callback,
        args=self.args,
        logger=self.logger,
    )
def test_executor_runtimes(signal, tmpdir, grpc_data_requests):
    """Run an executor in a subprocess, send one data request, then kill it with ``signal``.

    The request goes over gRPC when ``grpc_data_requests`` is truthy and over a ZMQ
    PUSH socket otherwise. The DummyExecutor's output file must show a proper close
    after handling one request.
    """
    import time

    args = set_pea_parser().parse_args([])

    def run(runtime_args):
        runtime_args.uses = {
            'jtype': 'DummyExecutor',
            'with': {'dir': str(tmpdir)},
            'metas': {'workspace': str(tmpdir)},
        }
        runtime_args.grpc_data_requests = grpc_data_requests
        executor_native(runtime_args)

    process = multiprocessing.Process(target=run, args=(args,))
    process.start()
    # crude wait for the executor to come up
    time.sleep(0.5)

    if not grpc_data_requests:
        push_socket = zmq.Context().socket(zmq.PUSH)
        push_socket.connect(f'tcp://localhost:{args.port_in}')
        push_socket.send_multipart(_create_test_data_message().dump())
    else:
        stub = Grpclet._create_grpc_stub(f'{args.host}:{args.port_in}', is_async=False)
        stub.Call(_create_test_data_message())
    time.sleep(0.1)

    os.kill(process.pid, signal)
    process.join()

    with open(f'{tmpdir}/test.txt', 'r') as fp:
        parts = fp.read().split(';')
    assert parts[0] == 'proper close'
    assert parts[1] == '1'
async def test_send_non_blocking(mocker):
    """Send two messages to a Grpclet whose callback blocks for 1s and check that
    the callback count is 1 after the first send and 2 after the second.
    """
    receive_cb = mocker.Mock()

    async def blocking_cb(msg):
        receive_cb()
        # synchronous sleep: deliberately blocks whatever loop runs this callback
        time.sleep(1.0)
        return msg

    args = set_pea_parser().parse_args([])
    grpclet = Grpclet(args=args, message_callback=blocking_cb)
    asyncio.get_event_loop().create_task(grpclet.start())

    receive_cb.assert_not_called()

    for expected_calls in (1, 2):
        await grpclet.send_message(_create_msg(args))
        await asyncio.sleep(0.1)
        assert receive_cb.call_count == expected_calls

    await grpclet.close(None)
def is_ready(ctrl_address: str, **kwargs) -> bool:
    """
    Probe readiness by sending a STATUS control message.

    :param ctrl_address: the address where the control message needs to be sent
    :param kwargs: extra keyword arguments
    :return: True if the STATUS request succeeded, False if it raised ``RpcError``.
    """
    try:
        Grpclet.send_ctrl_msg(ctrl_address, 'STATUS')
    except RpcError:
        return False
    else:
        return True
class GRPCDataRuntime(BaseRuntime, ABC):
    """Runtime procedure leveraging :class:`Grpclet` for sending DataRequests"""

    def __init__(self, args: argparse.Namespace, **kwargs):
        """Initialize grpc and data request handling.

        :param args: args from CLI
        :param kwargs: extra keyword arguments
        """
        super().__init__(args, **kwargs)
        self._id = random_identity()
        self._loop = get_or_reuse_loop()
        self._last_active_time = time.perf_counter()

        self._pending_msgs = defaultdict(list)  # type: Dict[str, List[Message]]
        self._partial_requests = None
        self._pending_tasks = []
        # Strong references to in-flight fire-and-forget send tasks. The event loop
        # only keeps weak references to tasks, so an unreferenced task could be
        # garbage collected before it finishes sending.
        self._inflight_send_tasks = set()
        self._static_routing_table = args.static_routing_table

        self._data_request_handler = DataRequestHandler(args, self.logger)
        self._grpclet = Grpclet(
            args=self.args,
            message_callback=self._callback,
            logger=self.logger,
        )

    def _update_pending_tasks(self):
        """Drop tasks that are already done from ``self._pending_tasks``."""
        self._pending_tasks = [task for task in self._pending_tasks if not task.done()]

    def _schedule_send(self, msg: Message) -> None:
        """Send ``msg`` through the Grpclet in a background task without awaiting it.

        The task is kept referenced until it completes.

        :param msg: the message to forward
        """
        task = asyncio.create_task(self._grpclet.send_message(msg))
        self._inflight_send_tasks.add(task)
        task.add_done_callback(self._inflight_send_tasks.discard)

    def run_forever(self):
        """Start the `Grpclet`."""
        self._grpclet_task = self._loop.create_task(self._grpclet.start())
        try:
            self._loop.run_until_complete(self._grpclet_task)
        except asyncio.CancelledError:
            # `_close_grpclet` cancels the grpclet task, which ends up here
            self.logger.warning('Grpclet task was cancelled')

    def teardown(self):
        """Close the `Grpclet` and `DataRequestHandler`."""
        self.logger.debug('Teardown GRPCDataRuntime')

        self._data_request_handler.close()
        # Wait (at most ~1s) for pending close tasks to finish.
        # NOTE(review): the event loop is not running here, so tasks may not make
        # progress during this wait — confirm whether this loop is effective.
        start = time.time()
        while self._pending_tasks and time.time() - start < 1.0:
            self._update_pending_tasks()
            time.sleep(0.1)
        self._loop.stop()
        self._loop.close()

        super().teardown()

    async def _close_grpclet(self):
        """Close the Grpclet and cancel the task started by `run_forever`."""
        await self._grpclet.close()
        self._grpclet_task.cancel()

    @staticmethod
    def get_control_address(**kwargs):
        """
        Does return None, exists for keeping interface compatible with ZEDRuntime

        :param kwargs: extra keyword arguments
        :returns: None
        """
        return None

    @staticmethod
    def is_ready(ctrl_address: str, **kwargs) -> bool:
        """
        Check if status is ready.

        :param ctrl_address: the address where the control message needs to be sent
        :param kwargs: extra keyword arguments

        :return: True if status is ready else False.
        """
        try:
            Grpclet.send_ctrl_msg(ctrl_address, 'STATUS')
        except RpcError:
            return False
        return True

    @staticmethod
    def activate(
        **kwargs,
    ):
        """
        Does nothing
        :param kwargs: extra keyword arguments
        """
        pass

    @staticmethod
    def cancel(
        control_address: str,
        **kwargs,
    ):
        """
        Cancel this runtime by sending a TERMINATE control message

        :param control_address: the address where the control message needs to be sent
        :param kwargs: extra keyword arguments
        """
        try:
            Grpclet.send_ctrl_msg(control_address, 'TERMINATE')
        except RpcError:
            # TERMINATE can fail if the runtime dies before sending the return value
            pass

    @staticmethod
    def wait_for_ready_or_shutdown(
        timeout: Optional[float],
        ctrl_address: str,
        shutdown_event: Union[multiprocessing.Event, threading.Event],
        **kwargs,
    ):
        """
        Check if the runtime has successfully started

        :param timeout: The time to wait before readiness or failure is determined;
            a falsy value (None or 0) means wait indefinitely
        :param ctrl_address: the address where the control message needs to be sent
        :param shutdown_event: the multiprocessing event to detect if the process failed
        :param kwargs: extra keyword arguments
        :return: True if is ready or it needs to be shutdown
        """
        timeout_ns = 1000000000 * timeout if timeout else None
        now = time.time_ns()
        while timeout_ns is None or time.time_ns() - now < timeout_ns:
            if shutdown_event.is_set() or GRPCDataRuntime.is_ready(ctrl_address):
                return True
            time.sleep(0.1)
        return False

    async def _callback(self, msg: Message) -> None:
        """Process one incoming message and forward data requests.

        The message goes through pre-hook, handling and post-hook. A data request
        is then forwarded exactly once, on both the success and the error path.

        :param msg: the received message
        """
        try:
            msg = self._post_hook(self._handle(self._pre_hook(msg)))
            if msg.is_data_request:
                self._schedule_send(msg)
        except RuntimeTerminated:
            # this is the proper way to end when a terminate signal is sent
            self._pending_tasks.append(asyncio.create_task(self._close_grpclet()))
        except KeyboardInterrupt as kbex:
            self.logger.debug(f'{kbex!r} causes the breaking from the event loop')
            self._pending_tasks.append(asyncio.create_task(self._close_grpclet()))
        except (SystemError) as ex:
            # save executor
            self.logger.debug(f'{ex!r} causes the breaking from the event loop')
            self._pending_tasks.append(asyncio.create_task(self._close_grpclet()))
        except NoExplicitMessage:
            # silent and do not propagate message anymore
            # 1. wait partial message to be finished
            pass
        except (RuntimeError, Exception, ChainedPodException) as ex:
            if self.args.on_error_strategy == OnErrorStrategy.THROW_EARLY:
                raise
            if isinstance(ex, ChainedPodException):
                # the error is print from previous pod, no need to show it again
                # hence just add exception and propagate further
                # please do NOT add logger.error here!
                msg.add_exception()
            else:
                msg.add_exception(ex, executor=self._data_request_handler._executor)
                # parenthesized so that --quiet-error only drops the hint, not the
                # exception repr itself
                self.logger.error(
                    f'{ex!r}'
                    + (
                        f'\n add "--quiet-error" to suppress the exception details'
                        if not self.args.quiet_error
                        else ''
                    ),
                    exc_info=not self.args.quiet_error,
                )

            # forward the (error-annotated) message exactly once
            if msg.is_data_request:
                self._schedule_send(msg)

    def _handle(self, msg: Message) -> Message:
        """Register the current message to this pea, so that all message-related properties are up-to-date, including
        :attr:`request`, :attr:`prev_requests`, :attr:`message`, :attr:`prev_messages`. And then call the executor to handle
        this message if its envelope's status is not ERROR, else skip handling of message.

        .. note::
            Handle does not handle explicitly message because it may wait for different messages when different parts are expected

        :param msg: received message
        :return: the transformed message.
        """
        # skip executor for non-DataRequest
        if msg.envelope.request_type != 'DataRequest':
            if msg.request.command == 'TERMINATE':
                raise RuntimeTerminated()
            self.logger.debug('skip executor: not data request')
            return msg

        req_id = msg.envelope.request_id
        num_expected_parts = self._get_expected_parts(msg)
        self._data_request_handler.handle(
            msg=msg,
            partial_requests=[m.request for m in self._pending_msgs[req_id]]
            if num_expected_parts > 1
            else None,
            peapod_name=self.name,
        )

        return msg

    def _get_expected_parts(self, msg):
        """Return how many parts of ``msg``'s request this runtime should collect.

        Data requests use the routing table's active target pod, or ``args.num_part``
        when a static routing table is configured; other messages count as 1 part.
        """
        if msg.is_data_request:
            if not self._static_routing_table:
                graph = RoutingTable(msg.envelope.routing_table)
                return graph.active_target_pod.expected_parts
            else:
                return self.args.num_part
        else:
            return 1

    def _pre_hook(self, msg: Message) -> Message:
        """
        Pre-hook function, what to do after first receiving the message.

        :param msg: received message
        :return: `Message`
        """
        msg.add_route(self.name, self._id)

        expected_parts = self._get_expected_parts(msg)

        req_id = msg.envelope.request_id
        if expected_parts > 1:
            self._pending_msgs[req_id].append(msg)

        # use .get() so single-part requests do not insert an empty entry into the
        # defaultdict (such entries are never popped by `_post_hook`)
        num_partial_requests = len(self._pending_msgs.get(req_id, ()))

        if self.logger.debug_enabled:
            self._log_info_msg(
                msg,
                f'({num_partial_requests}/{expected_parts} parts)'
                if expected_parts > 1
                else '',
            )

        if expected_parts > 1 and expected_parts > num_partial_requests:
            # NOTE: reduce priority is higher than chain exception
            # otherwise a reducer will lose its function when earlier pods raise exception
            raise NoExplicitMessage

        if (
            msg.envelope.status.code == jina_pb2.StatusProto.ERROR
            and self.args.on_error_strategy >= OnErrorStrategy.SKIP_HANDLE
        ):
            raise ChainedPodException

        return msg

    def _log_info_msg(self, msg, part_str):
        """Log (at debug level) a one-line summary of a received message."""
        info_msg = f'recv {msg.envelope.request_type} '
        req_type = msg.envelope.request_type
        if req_type == 'DataRequest':
            info_msg += (
                f'({msg.envelope.header.exec_endpoint}) - ({msg.envelope.request_id}) '
            )
        elif req_type == 'ControlRequest':
            info_msg += f'({msg.request.command}) '
        info_msg += f'{part_str} from {msg.colored_route}'
        self.logger.debug(info_msg)

    def _post_hook(self, msg: Message) -> Message:
        """
        Post-hook function, what to do before handing out the message.

        :param msg: the transformed message
        :return: `Message`
        """
        # do NOT access `msg.request.*` in the _pre_hook, as it will trigger the deserialization
        # all meta information should be stored and accessed via `msg.envelope`

        self._last_active_time = time.perf_counter()

        if self._get_expected_parts(msg) > 1:
            msgs = self._pending_msgs.pop(msg.envelope.request_id)
            msg.merge_envelope_from(msgs)

        msg.update_timestamp()
        return msg
def send_status():
    """Send a STATUS control message to ``args.host:args.port_in``.

    NOTE(review): ``args`` is a free variable here, presumably bound in an enclosing scope.

    :return: whatever ``Grpclet.send_ctrl_msg`` returns
    """
    target = f'{args.host}:{args.port_in}'
    return Grpclet.send_ctrl_msg(pod_address=target, command='STATUS')