Example #1
def test_grpc_data_runtime(mocker):
    args = set_pea_parser().parse_args([])
    handle_mock = mocker.Mock()

    def start_runtime(args, handle_mock):
        with GRPCDataRuntime(args) as runtime:
            runtime._data_request_handler.handle = handle_mock
            runtime.run_forever()

    runtime_thread = Thread(
        target=start_runtime,
        args=(
            args,
            handle_mock,
        ),
    )
    runtime_thread.start()

    assert GRPCDataRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'{args.host}:{args.port_in}',
        shutdown_event=Event())

    Grpclet._create_grpc_stub(f'{args.host}:{args.port_in}',
                              is_async=False).Call(_create_test_data_message())
    time.sleep(0.1)
    handle_mock.assert_called()

    GRPCDataRuntime.cancel(f'{args.host}:{args.port_in}')
    runtime_thread.join()

    assert not GRPCDataRuntime.is_ready(f'{args.host}:{args.port_in}')
Example #2
async def test_send_static_ctrl_msg(mocker):
    # AsyncMock does not seem to exist in python 3.7, this is a manual workaround
    receive_cb = mocker.Mock()

    async def mock_wrapper(msg):
        receive_cb()

    args = set_pea_parser().parse_args([])
    grpclet = Grpclet(args=args, message_callback=mock_wrapper)
    asyncio.get_event_loop().create_task(grpclet.start())

    receive_cb.assert_not_called()

    while True:
        try:

            def send_status():
                return Grpclet.send_ctrl_msg(
                    pod_address=f'{args.host}:{args.port_in}',
                    command='STATUS')

            await asyncio.get_event_loop().run_in_executor(None, send_status)
            break
        except RpcError:
            await asyncio.sleep(0.1)

    receive_cb.assert_called()
    await grpclet.close(None)
Example #3
def test_grpc_data_runtime(mocker):
    args = set_pea_parser().parse_args([])
    handle_mock = multiprocessing.Event()

    cancel_event = multiprocessing.Event()

    def start_runtime(args, handle_mock, cancel_event):
        with GRPCDataRuntime(args, cancel_event) as runtime:
            runtime._data_request_handler.handle = (
                lambda *args, **kwargs: handle_mock.set()
            )
            runtime.run_forever()

    runtime_thread = Process(
        target=start_runtime,
        args=(args, handle_mock, cancel_event),
        daemon=True,
    )
    runtime_thread.start()

    assert GRPCDataRuntime.wait_for_ready_or_shutdown(
        timeout=5.0, ctrl_address=f'{args.host}:{args.port_in}', shutdown_event=Event()
    )

    Grpclet._create_grpc_stub(f'{args.host}:{args.port_in}', is_async=False).Call(
        _create_test_data_message()
    )
    time.sleep(0.1)
    assert handle_mock.is_set()

    GRPCDataRuntime.cancel(cancel_event)
    runtime_thread.join()

    assert not GRPCDataRuntime.is_ready(f'{args.host}:{args.port_in}')
Example #4
def test_grpc_data_runtime_waits_for_pending_messages_shutdown(close_method):
    args = set_pea_parser().parse_args([])

    cancel_event = multiprocessing.Event()
    handler_closed_event = multiprocessing.Event()
    slow_executor_block_time = 1.0
    pending_requests = 3
    sent_queue = multiprocessing.Queue()

    def start_runtime(args, cancel_event, sent_queue, handler_closed_event):
        with GRPCDataRuntime(args, cancel_event) as runtime:
            runtime._data_request_handler.handle = lambda *args, **kwargs: time.sleep(
                slow_executor_block_time
            )
            runtime._data_request_handler.close = (
                lambda *args, **kwargs: handler_closed_event.set()
            )

            async def mock(msg):
                sent_queue.put('')

            runtime._grpclet.send_message = mock

            runtime.run_forever()

    runtime_thread = Process(
        target=start_runtime,
        args=(args, cancel_event, sent_queue, handler_closed_event),
        daemon=True,
    )
    runtime_thread.start()

    assert GRPCDataRuntime.wait_for_ready_or_shutdown(
        timeout=5.0, ctrl_address=f'{args.host}:{args.port_in}', shutdown_event=Event()
    )

    request_start_time = time.time()
    for i in range(pending_requests):
        Grpclet._create_grpc_stub(f'{args.host}:{args.port_in}', is_async=False).Call(
            _create_test_data_message()
        )
    time.sleep(0.1)

    if close_method == 'TERMINATE':
        runtime_thread.terminate()
    else:
        GRPCDataRuntime.cancel(cancel_event)
    assert not handler_closed_event.is_set()
    runtime_thread.join()

    assert (
        time.time() - request_start_time >= slow_executor_block_time * pending_requests
    )
    assert sent_queue.qsize() == pending_requests
    assert handler_closed_event.is_set()

    assert not GRPCDataRuntime.is_ready(f'{args.host}:{args.port_in}')
Example #5
    def cancel(
        control_address: str,
        **kwargs,
    ):
        """
        Cancel this runtime by sending a TERMINATE control message

        :param control_address: the address where the control message needs to be sent
        :param kwargs: extra keyword arguments
        """
        try:
            Grpclet.send_ctrl_msg(control_address, 'TERMINATE')
        except RpcError:
            # TERMINATE can fail if the runtime dies before sending the return value
            pass
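
As Example #1 above shows, the caller passes the pea's control address (host plus `port_in`). A minimal usage sketch, assuming `args` comes from `set_pea_parser().parse_args([])` as in the tests:

args = set_pea_parser().parse_args([])
ctrl_address = f'{args.host}:{args.port_in}'

# Ask the running runtime to shut down via a TERMINATE control message.
# cancel() swallows RpcError, so calling it on an already-dead runtime is harmless.
GRPCDataRuntime.cancel(ctrl_address)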
Example #6
async def test_send_receive(mocker):
    # AsyncMock does not seem to exist in python 3.7, this is a manual workaround
    receive_cb = mocker.Mock()

    async def mock_wrapper(msg):
        receive_cb()

    args = set_pea_parser().parse_args([])
    grpclet = Grpclet(args=args, message_callback=mock_wrapper)
    asyncio.get_event_loop().create_task(grpclet.start())

    receive_cb.assert_not_called()

    await grpclet.send_message(_create_msg(args))
    await asyncio.sleep(0.1)
    receive_cb.assert_called()

    await grpclet.close(None)
Example #7
    def __init__(self, args: argparse.Namespace, **kwargs):
        """Initialize grpc and data request handling.
        :param args: args from CLI
        :param kwargs: extra keyword arguments
        """
        super().__init__(args, **kwargs)
        self._id = random_identity()
        self._loop = get_or_reuse_loop()
        self._last_active_time = time.perf_counter()

        self._pending_msgs = defaultdict(list)  # type: Dict[str, List[Message]]
        self._partial_requests = None
        self._pending_tasks = []
        self._static_routing_table = args.static_routing_table

        self._data_request_handler = DataRequestHandler(args, self.logger)
        self._grpclet = Grpclet(
            args=self.args,
            message_callback=self._callback,
            logger=self.logger,
        )
Example #8
def test_executor_runtimes(signal, tmpdir, grpc_data_requests):
    import time

    args = set_pea_parser().parse_args([])

    def run(args):

        args.uses = {
            'jtype': 'DummyExecutor',
            'with': {
                'dir': str(tmpdir)
            },
            'metas': {
                'workspace': str(tmpdir)
            },
        }
        args.grpc_data_requests = grpc_data_requests
        executor_native(args)

    process = multiprocessing.Process(target=run, args=(args, ))
    process.start()
    time.sleep(0.5)

    if grpc_data_requests:
        Grpclet._create_grpc_stub(f'{args.host}:{args.port_in}',
                                  is_async=False).Call(
                                      _create_test_data_message())
    else:
        socket = zmq.Context().socket(zmq.PUSH)
        socket.connect(f'tcp://localhost:{args.port_in}')
        socket.send_multipart(_create_test_data_message().dump())
    time.sleep(0.1)

    os.kill(process.pid, signal)
    process.join()
    with open(f'{tmpdir}/test.txt', 'r') as fp:
        output = fp.read()
    split = output.split(';')
    assert split[0] == 'proper close'
    assert split[1] == '1'
Example #9
async def test_send_non_blocking(mocker):
    receive_cb = mocker.Mock()

    async def blocking_cb(msg):
        receive_cb()
        time.sleep(1.0)
        return msg

    args = set_pea_parser().parse_args([])
    grpclet = Grpclet(args=args, message_callback=blocking_cb)
    asyncio.get_event_loop().create_task(grpclet.start())

    receive_cb.assert_not_called()
    await grpclet.send_message(_create_msg(args))

    await asyncio.sleep(0.1)
    assert receive_cb.call_count == 1
    await grpclet.send_message(_create_msg(args))
    await asyncio.sleep(0.1)
    assert receive_cb.call_count == 2

    await grpclet.close(None)
Example #10
    def is_ready(ctrl_address: str, **kwargs) -> bool:
        """
        Check if status is ready.

        :param ctrl_address: the address where the control message needs to be sent
        :param kwargs: extra keyword arguments

        :return: True if status is ready else False.
        """

        try:
            response = Grpclet.send_ctrl_msg(ctrl_address, 'STATUS')
        except RpcError:
            return False

        return True
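
A readiness poll built on this helper, sketched from what `wait_for_ready_or_shutdown` in Example #11 does internally (the timeout and sleep values are illustrative only):

import time

args = set_pea_parser().parse_args([])
ctrl_address = f'{args.host}:{args.port_in}'

deadline = time.time() + 5.0
ready = False
while time.time() < deadline:
    # is_ready() sends a STATUS control message and maps RpcError to False.
    if GRPCDataRuntime.is_ready(ctrl_address):
        ready = True
        break
    time.sleep(0.1)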
Example #11
class GRPCDataRuntime(BaseRuntime, ABC):
    """Runtime procedure leveraging :class:`Grpclet` for sending DataRequests"""

    def __init__(self, args: argparse.Namespace, **kwargs):
        """Initialize grpc and data request handling.
        :param args: args from CLI
        :param kwargs: extra keyword arguments
        """
        super().__init__(args, **kwargs)
        self._id = random_identity()
        self._loop = get_or_reuse_loop()
        self._last_active_time = time.perf_counter()

        self._pending_msgs = defaultdict(list)  # type: Dict[str, List[Message]]
        self._partial_requests = None
        self._pending_tasks = []
        self._static_routing_table = args.static_routing_table

        self._data_request_handler = DataRequestHandler(args, self.logger)
        self._grpclet = Grpclet(
            args=self.args,
            message_callback=self._callback,
            logger=self.logger,
        )

    def _update_pending_tasks(self):
        self._pending_tasks = [task for task in self._pending_tasks if not task.done()]

    def run_forever(self):
        """Start the `Grpclet`."""
        self._grpclet_task = self._loop.create_task(self._grpclet.start())
        try:
            self._loop.run_until_complete(self._grpclet_task)
        except asyncio.CancelledError:
            self.logger.warning('Grpclet task was cancelled')

    def teardown(self):
        """Close the `Grpclet` and `DataRequestHandler`."""
        self.logger.debug('Teardown GRPCDataRuntime')

        self._data_request_handler.close()
        start = time.time()
        while self._pending_tasks and time.time() - start < 1.0:
            self._update_pending_tasks()
            time.sleep(0.1)
        self._loop.stop()
        self._loop.close()

        super().teardown()

    async def _close_grpclet(self):
        await self._grpclet.close()
        self._grpclet_task.cancel()

    @staticmethod
    def get_control_address(**kwargs):
        """
        Returns None; exists only to keep the interface compatible with ZEDRuntime

        :param kwargs: extra keyword arguments
        :returns: None
        """
        return None

    @staticmethod
    def is_ready(ctrl_address: str, **kwargs) -> bool:
        """
        Check if status is ready.

        :param ctrl_address: the address where the control message needs to be sent
        :param kwargs: extra keyword arguments

        :return: True if status is ready else False.
        """

        try:
            response = Grpclet.send_ctrl_msg(ctrl_address, 'STATUS')
        except RpcError:
            return False

        return True

    @staticmethod
    def activate(
        **kwargs,
    ):
        """
        Does nothing
        :param kwargs: extra keyword arguments
        """
        pass

    @staticmethod
    def cancel(
        control_address: str,
        **kwargs,
    ):
        """
        Cancel this runtime by sending a TERMINATE control message

        :param control_address: the address where the control message needs to be sent
        :param kwargs: extra keyword arguments
        """
        try:
            Grpclet.send_ctrl_msg(control_address, 'TERMINATE')
        except RpcError:
            # TERMINATE can fail if the runtime dies before sending the return value
            pass

    @staticmethod
    def wait_for_ready_or_shutdown(
        timeout: Optional[float],
        ctrl_address: str,
        shutdown_event: Union[multiprocessing.Event, threading.Event],
        **kwargs,
    ):
        """
        Check if the runtime has successfully started

        :param timeout: The time to wait before readiness or failure is determined
        :param ctrl_address: the address where the control message needs to be sent
        :param shutdown_event: the multiprocessing event to detect if the process failed
        :param kwargs: extra keyword arguments
        :return: True if the runtime is ready or needs to be shut down
        """
        timeout_ns = 1000000000 * timeout if timeout else None
        now = time.time_ns()
        while timeout_ns is None or time.time_ns() - now < timeout_ns:
            if shutdown_event.is_set() or GRPCDataRuntime.is_ready(ctrl_address):
                return True
            time.sleep(0.1)

        return False

    async def _callback(self, msg: Message) -> None:
        try:
            msg = self._post_hook(self._handle(self._pre_hook(msg)))
            if msg.is_data_request:
                asyncio.create_task(self._grpclet.send_message(msg))
        except RuntimeTerminated:
            # this is the proper way to end when a terminate signal is sent
            self._pending_tasks.append(asyncio.create_task(self._close_grpclet()))
        except KeyboardInterrupt as kbex:
            self.logger.debug(f'{kbex!r} causes the breaking from the event loop')
            self._pending_tasks.append(asyncio.create_task(self._close_grpclet()))
        except SystemError as ex:
            # save executor
            self.logger.debug(f'{ex!r} causes the breaking from the event loop')
            self._pending_tasks.append(asyncio.create_task(self._close_grpclet()))
        except NoExplicitMessage:
            # silent and do not propagate message anymore
            # 1. wait partial message to be finished
            pass
        except (RuntimeError, Exception, ChainedPodException) as ex:
            if self.args.on_error_strategy == OnErrorStrategy.THROW_EARLY:
                raise
            if isinstance(ex, ChainedPodException):
                # the error is printed from the previous pod, no need to show it again
                # hence just add exception and propagate further
                # please do NOT add logger.error here!
                msg.add_exception()
            else:
                msg.add_exception(ex, executor=self._data_request_handler._executor)
                self.logger.error(
                    f'{ex!r}'
                    + f'\n add "--quiet-error" to suppress the exception details'
                    if not self.args.quiet_error
                    else '',
                    exc_info=not self.args.quiet_error,
                )

            if msg.is_data_request:
                asyncio.create_task(self._grpclet.send_message(msg))

    def _handle(self, msg: Message) -> Message:
        """Register the current message to this pea, so that all message-related properties are up-to-date, including
        :attr:`request`, :attr:`prev_requests`, :attr:`message`, :attr:`prev_messages`. And then call the executor to handle
        this message if its envelope's  status is not ERROR, else skip handling of message.
        .. note::
            Handle does not handle explicitly message because it may wait for different messages when different parts are expected
        :param msg: received message
        :return: the transformed message.
        """

        # skip executor for non-DataRequest
        if msg.envelope.request_type != 'DataRequest':
            if msg.request.command == 'TERMINATE':
                raise RuntimeTerminated()
            self.logger.debug(f'skip executor: not data request')
            return msg

        req_id = msg.envelope.request_id
        num_expected_parts = self._get_expected_parts(msg)
        self._data_request_handler.handle(
            msg=msg,
            partial_requests=[m.request for m in self._pending_msgs[req_id]]
            if num_expected_parts > 1
            else None,
            peapod_name=self.name,
        )

        return msg

    def _get_expected_parts(self, msg):
        if msg.is_data_request:
            if not self._static_routing_table:
                graph = RoutingTable(msg.envelope.routing_table)
                return graph.active_target_pod.expected_parts
            else:
                return self.args.num_part
        else:
            return 1

    def _pre_hook(self, msg: Message) -> Message:
        """
        Pre-hook function, what to do after first receiving the message.
        :param msg: received message
        :return: `Message`
        """
        msg.add_route(self.name, self._id)

        expected_parts = self._get_expected_parts(msg)

        req_id = msg.envelope.request_id
        if expected_parts > 1:
            self._pending_msgs[req_id].append(msg)

        num_partial_requests = len(self._pending_msgs[req_id])

        if self.logger.debug_enabled:
            self._log_info_msg(
                msg,
                f'({num_partial_requests}/{expected_parts} parts)'
                if expected_parts > 1
                else '',
            )

        if expected_parts > 1 and expected_parts > num_partial_requests:
            # NOTE: reduce priority is higher than chain exception
            # otherwise a reducer will lose its function when earlier pods raise exception
            raise NoExplicitMessage

        if (
            msg.envelope.status.code == jina_pb2.StatusProto.ERROR
            and self.args.on_error_strategy >= OnErrorStrategy.SKIP_HANDLE
        ):
            raise ChainedPodException

        return msg

    def _log_info_msg(self, msg, part_str):
        info_msg = f'recv {msg.envelope.request_type} '
        req_type = msg.envelope.request_type
        if req_type == 'DataRequest':
            info_msg += (
                f'({msg.envelope.header.exec_endpoint}) - ({msg.envelope.request_id}) '
            )
        elif req_type == 'ControlRequest':
            info_msg += f'({msg.request.command}) '
        info_msg += f'{part_str} from {msg.colored_route}'
        self.logger.debug(info_msg)

    def _post_hook(self, msg: Message) -> Message:
        """
        Post-hook function, what to do before handing out the message.
        :param msg: the transformed message
        :return: `Message`
        """
        # do NOT access `msg.request.*` in the _pre_hook, as it will trigger the deserialization
        # all meta information should be stored and accessed via `msg.envelope`

        self._last_active_time = time.perf_counter()

        if self._get_expected_parts(msg) > 1:
            msgs = self._pending_msgs.pop(msg.envelope.request_id)
            msg.merge_envelope_from(msgs)

        msg.update_timestamp()
        return msg
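
Condensed from Examples #1 and #8, a sketch of driving this runtime end to end. It assumes the test helpers `set_pea_parser` and `_create_test_data_message` are importable, and it uses the address-based `cancel` shown in this version of the class:

import multiprocessing
import time
from threading import Event

def _run(args):
    # The context manager handles setup/teardown; run_forever() drives the Grpclet.
    with GRPCDataRuntime(args) as runtime:
        runtime.run_forever()

args = set_pea_parser().parse_args([])
ctrl_address = f'{args.host}:{args.port_in}'

proc = multiprocessing.Process(target=_run, args=(args,), daemon=True)
proc.start()

# Block until the runtime answers a STATUS control message (or the timeout expires).
assert GRPCDataRuntime.wait_for_ready_or_shutdown(
    timeout=5.0, ctrl_address=ctrl_address, shutdown_event=Event()
)

# Push one data message through the synchronous gRPC stub, as the tests do.
Grpclet._create_grpc_stub(ctrl_address, is_async=False).Call(_create_test_data_message())
time.sleep(0.1)

# Shut down via a TERMINATE control message and wait for the child process to exit.
GRPCDataRuntime.cancel(ctrl_address)
proc.join()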
Example #12
def send_status():
    return Grpclet.send_ctrl_msg(
        pod_address=f'{args.host}:{args.port_in}',
        command='STATUS')