Example #1
def _task_proxy(send_conn, _bind: Callable):
    with util.socket_factory(zmq.PULL,
                             zmq.PUSH) as (zmq_ctx, proxy_in, proxy_out):
        # report the bound addresses (or the failure) back to the parent
        with send_conn:
            try:
                send_conn.send_bytes(
                    serializer.dumps([_bind(proxy_in),
                                      _bind(proxy_out)]))
            except Exception:
                send_conn.send_bytes(serializer.dumps(RemoteException()))
        try:
            # blocks forever, shovelling messages from the PULL to the PUSH end
            zmq.proxy(proxy_in, proxy_out)
        except Exception:
            util.log_internal_crash("Task proxy")
Example #2
    def run(
        self,
        target: Callable = None,
        args: Sequence = None,
        kwargs: Mapping = None,
        *,
        pass_state: bool = False,
        lazy: bool = False,
    ):
        if target is None:
            # used as a decorator factory; the decorated function becomes the target
            def wrapper(target, *a, **k):
                return self.run(target, a, k, pass_state=pass_state, lazy=lazy)

            return wrapper

        task_id = util.generate_task_id()
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}
        # mirrors the map-task params layout, with the map_* slots left empty
        params = (None, None, args, None, kwargs)
        task = (target, params, pass_state, self.namespace)

        # index -1 marks a single, non-chunked task
        self._task_push.send_multipart(
            [util.encode_chunk_id(task_id, -1),
             serializer.dumps(task)])

        res = SimpleTaskResult(self.server_address, task_id)
        if lazy:
            return res
        return res.value
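
A hypothetical usage sketch (``workers`` is an illustrative name for an object exposing this ``run`` method, not a confirmed zproc API):

def add(x, y):
    return x + y

# direct call: blocks until the task finishes, then returns its value
assert workers.run(add, args=(1, 2)) == 3

# lazy call: returns a SimpleTaskResult immediately
res = workers.run(add, args=(1, 2), lazy=True)
print(res.value)  # 3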
Example #3
def main(server_address: str, send_conn):
    with util.socket_factory(zmq.ROUTER, zmq.ROUTER) as (
            zmq_ctx,
            state_router,
            watch_router,
    ):
        atexit.register(util.clean_process_tree)

        try:
            if server_address:
                state_router.bind(server_address)
                # use the same transport (ipc/tcp) for the auxiliary sockets
                if "ipc" in server_address:
                    _bind = util.bind_to_random_ipc
                else:
                    _bind = util.bind_to_random_tcp
            else:
                _bind = util.bind_to_random_address
                server_address = _bind(state_router)

            server_meta = ServerMeta(__version__, server_address,
                                     _bind(watch_router),
                                     *start_task_server(_bind),
                                     *start_task_proxy(_bind))

            state_server = StateServer(state_router, watch_router, server_meta)
        except Exception:
            with send_conn:
                send_conn.send_bytes(
                    serializer.dumps(exceptions.RemoteException()))
            return
        else:
            with send_conn:
                send_conn.send_bytes(serializer.dumps(server_meta))

        while True:
            try:
                state_server.tick()
            except KeyboardInterrupt:
                util.log_internal_crash("State server")
                return
            except Exception:
                if state_server.identity is None:
                    util.log_internal_crash("State server")
                else:
                    state_server.reply(RemoteException())
            finally:
                state_server.reset_internal_state()
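
The ``bind_to_random_*`` helpers are not shown in these excerpts. One plausible shape for the TCP variant, built on pyzmq's stock ``bind_to_random_port`` (an assumption, not zproc's actual code):

import zmq

def bind_to_random_tcp(sock: zmq.Socket) -> str:
    port = sock.bind_to_random_port("tcp://*")  # stock pyzmq helper
    return "tcp://127.0.0.1:%d" % port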
Example #4
def ping(server_address: str,
         *,
         timeout: float = None,
         payload: bytes = None) -> int:
    """
    Ping the zproc server.

    This can be used to detect whether a server is alive and running, with the aid of a suitable ``timeout``.

    :param server_address:
        .. include:: /api/snippets/server_address.rst
    :param timeout:
        The timeout in seconds.

        If this is set to ``None``, it will block forever, until the zproc server replies.

        For any other value, it will wait that many seconds for a reply,
        and raise a :py:class:`TimeoutError` if none arrives in time.

        By default it is set to ``None``.
    :param payload:
        The payload that will be sent to the server, and echoed back in its reply.

        If it is set to None, then ``os.urandom(56)`` (56 random bytes) will be used.

        (No real reason for the ``56`` magic number.)

    :return:
        The zproc server's **pid**.
    """
    if payload is None:
        payload = os.urandom(56)

    with util.create_zmq_ctx() as zmq_ctx:
        with zmq_ctx.socket(zmq.DEALER) as dealer_sock:
            dealer_sock.connect(server_address)
            if timeout is not None:
                dealer_sock.setsockopt(zmq.RCVTIMEO, int(timeout * 1000))

            dealer_sock.send(
                serializer.dumps({
                    Msgs.cmd: Cmds.ping,
                    Msgs.info: payload
                }))

            try:
                recv_payload, pid = serializer.loads(dealer_sock.recv())
            except zmq.error.Again:
                raise TimeoutError(
                    "Timed-out while waiting for the ZProc server to respond.")

            assert (
                recv_payload == payload
            ), "Payload doesn't match! The server connection may be compromised, or unstable."

            return pid
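
A minimal usage sketch, assuming a zproc server is already running (here, one started by ``zproc.Context()``):

import zproc

ctx = zproc.Context()  # spawns a server in the background
pid = zproc.ping(ctx.server_address, timeout=1)
print("server pid:", pid)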
Example #5
def _task_server(send_conn, _bind: Callable):
    with util.socket_factory(zmq.ROUTER,
                             zmq.PULL) as (zmq_ctx, router, result_pull):
        with send_conn:
            try:
                send_conn.send_bytes(
                    serializer.dumps([_bind(router),
                                      _bind(result_pull)]))
                server = TaskResultServer(router, result_pull)
            except Exception:
                send_conn.send_bytes(serializer.dumps(RemoteException()))
                return
        while True:
            try:
                server.tick()
            except KeyboardInterrupt:
                util.log_internal_crash("Task server")
                return
            except Exception:
                util.log_internal_crash("Task server")
Example #6
def worker_process(server_address: str, send_conn):
    with util.socket_factory(zmq.PULL,
                             zmq.PUSH) as (zmq_ctx, task_pull, result_push):
        server_meta = util.get_server_meta(zmq_ctx, server_address)

        try:
            task_pull.connect(server_meta.task_proxy_out)
            result_push.connect(server_meta.task_result_pull)
            state = State(server_address)
        except Exception:
            with send_conn:
                send_conn.send_bytes(serializer.dumps(RemoteException()))
            return
        else:
            with send_conn:
                send_conn.send_bytes(b"")

        try:
            while True:
                msg = task_pull.recv_multipart()
                if msg == EMPTY_MULTIPART:
                    # an empty message tells the worker to exit
                    return
                chunk_id, target_bytes, task_bytes = msg

                try:
                    task = serializer.loads(task_bytes)
                    target = serializer.loads_fn(target_bytes)

                    result = run_task(target, task, state)
                except KeyboardInterrupt:
                    raise
                except Exception:
                    result = RemoteException()
                result_push.send_multipart(
                    [chunk_id, serializer.dumps(result)])
        except Exception:
            util.log_internal_crash("Worker process")
Example #7
    def recv_request(self):
        ident, chunk_id = self.router.recv_multipart()
        try:
            task_id, index = util.decode_chunk_id(chunk_id)
            # print("request->", task_id, index)
            task_store = self.result_store[task_id]
            try:
                chunk_result = task_store[index]
            except KeyError:
                # result not ready yet; park this client until it arrives
                self.pending[chunk_id].appendleft(ident)
            else:
                self.router.send_multipart([ident, chunk_result])
        except KeyboardInterrupt:
            raise
        except Exception:
            self.router.send_multipart(
                [ident, serializer.dumps(RemoteException())])
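
``encode_chunk_id``/``decode_chunk_id`` are not shown in these excerpts; judging by the call sites, they pack a task id and a chunk index into a single bytes key. A sketch of one way that could look (hypothetical, not zproc's actual encoding):

import struct

def encode_chunk_id(task_id: bytes, index: int) -> bytes:
    # append a signed 32-bit index; -1 marks a single, non-chunked task
    return task_id + struct.pack("<i", index)

def decode_chunk_id(chunk_id: bytes):
    (index,) = struct.unpack("<i", chunk_id[-4:])
    return chunk_id[:-4], index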
Example #8
    def mutate_safely(self):
        # snapshot the state, so a failed mutation can be rolled back
        old = deepcopy(self.state)
        stamp = time.time()

        try:
            yield
        except Exception:
            # restore the pre-mutation snapshot, then re-raise to the caller
            self.state = self.state_map[self.namespace] = old
            raise

        # record the change in this namespace's history, then service watchers
        slot = self.history[self.namespace]
        slot[0].append(stamp)
        slot[1].append([
            self.identity,
            serializer.dumps((old, self.state, stamp)),
            self.state == old,
        ])
        self.resolve_pending()
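
The rollback idea in isolation: a minimal, self-contained sketch of the same snapshot-and-restore pattern (illustrative, not zproc's API):

from contextlib import contextmanager
from copy import deepcopy

class Store:
    def __init__(self):
        self.state = {}

    @contextmanager
    def mutate_safely(self):
        old = deepcopy(self.state)  # snapshot before mutating
        try:
            yield self.state
        except Exception:
            self.state = old  # roll back on failure
            raise

store = Store()
try:
    with store.mutate_safely() as s:
        s["x"] = 1
        raise ValueError("boom")
except ValueError:
    pass
assert store.state == {}  # the mutation was rolled back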
Example #9
    def main(self):
        @wraps(self.target)
        def target_wrapper(*args, **kwargs):
            while True:
                self.retries += 1
                try:
                    return self.target(*args, **kwargs)
                except exceptions.ProcessExit as e:
                    self.exitcode = e.exitcode
                    return None
                except self.to_catch as e:
                    self._handle_exc(e, handle_retry=True)

                    # swap in the retry arguments for the next attempt
                    if self.retry_args is not None:
                        self.target_args = args = self.retry_args
                    if self.retry_kwargs is not None:
                        self.target_kwargs = kwargs = self.retry_kwargs

        try:
            if self.pass_context:
                from .context import Context  # this helps avoid a circular import

                return_value = target_wrapper(
                    Context(
                        self.kwargs["server_address"],
                        namespace=self.kwargs["namespace"],
                        start_server=False,
                    ), *self.target_args, **self.target_kwargs)
            else:
                return_value = target_wrapper(*self.target_args,
                                              **self.target_kwargs)
            # print(return_value)
            with util.create_zmq_ctx(linger=True) as zmq_ctx:
                with zmq_ctx.socket(zmq.PAIR) as result_sock:
                    result_sock.connect(self.kwargs["result_address"])
                    result_sock.send(serializer.dumps(return_value))
        except Exception as e:
            self._handle_exc(e)
        finally:
            util.clean_process_tree(self.exitcode)
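
The retry loop above, reduced to its core: a minimal standalone sketch of the catch-and-retry pattern (illustrative only, with a hypothetical back-off delay):

import time

def run_with_retries(target, args, to_catch=(Exception,), max_retries=3, delay=0.1):
    retries = 0
    while True:
        retries += 1
        try:
            return target(*args)
        except to_catch:
            if retries > max_retries:
                raise
            time.sleep(delay)  # back off before the next attempt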
Example #10
    def _s_request_reply(self, request: Dict[int, Any]):
        request[Msgs.namespace] = self._namespace_bytes
        msg = serializer.dumps(request)
        return serializer.loads(
            util.strict_request_reply(msg, self._s_dealer.send,
                                      self._s_dealer.recv))
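
``strict_request_reply`` is not shown in these excerpts; judging by the call site, it pairs exactly one send with one recv. A sketch under that assumption (hypothetical, not zproc's actual code):

def strict_request_reply(msg, send, recv):
    # one request, then exactly one reply; transport errors propagate as-is
    send(msg)
    return recv()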
Example #11
    def reply(self, response):
        # print("server rep:", self.identity, response, time.time())
        self.state_router.send_multipart(
            [self.identity, serializer.dumps(response)])
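
The identity-prefixed reply above is the standard ROUTER pattern: a ROUTER socket prepends each sender's identity frame on recv, and routes a reply back by that same frame. A minimal, self-contained echo sketch in plain pyzmq (fixed port for illustration only):

import zmq

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
router.bind("tcp://127.0.0.1:5557")
dealer = ctx.socket(zmq.DEALER)
dealer.connect("tcp://127.0.0.1:5557")

dealer.send(b"ping")
ident, msg = router.recv_multipart()  # ROUTER prepends the sender's identity
router.send_multipart([ident, msg])   # the reply must carry that identity back
print(dealer.recv())  # b"ping"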
Example #12
# create the IPC directory up-front; exist_ok avoids a check-then-create race
IPC_BASE_DIR.mkdir(parents=True, exist_ok=True)


def create_ipc_address(name: str) -> str:
    return "ipc://" + str(IPC_BASE_DIR / name)


def get_server_meta(zmq_ctx: zmq.Context, server_address: str) -> ServerMeta:
    with zmq_ctx.socket(zmq.DEALER) as dealer:
        dealer.connect(server_address)
        return req_server_meta(dealer)


_server_meta_req_cache = serializer.dumps({
    Msgs.cmd: Cmds.get_server_meta,
    Msgs.namespace: DEFAULT_NAMESPACE
})


def req_server_meta(dealer: zmq.Socket) -> ServerMeta:
    dealer.send(_server_meta_req_cache)
    server_meta = serializer.loads(dealer.recv())
    if server_meta.version != __version__:
        raise RuntimeError(
            "The server version didn't match. "
            "Please make sure the server (%r) is using the same version of ZProc as this client (%r)."
            % (server_meta.version, __version__))
    return server_meta


def to_catchable_exc(
Example #13
    def map_lazy(
        self,
        target: Callable,
        map_iter: Sequence[Any] = None,
        *,
        map_args: Sequence[Sequence[Any]] = None,
        args: Sequence = None,
        map_kwargs: Sequence[Mapping[str, Any]] = None,
        kwargs: Mapping = None,
        pass_state: bool = False,
        num_chunks: int = None,
    ) -> SequenceTaskResult:
        r"""
        Functional equivalent of ``map()`` in-built function,
        but executed in a parallel fashion.

        Distributes the iterables,
        provided in the ``map_*`` arguments to ``num_chunks`` no of worker nodes.

        The idea is to:
            1. Split the the iterables provided in the ``map_*`` arguments into ``num_chunks`` no of equally sized chunks.
            2. Send these chunks to ``num_chunks`` number of worker nodes.
            3. Wait for all these worker nodes to finish their task(s).
            4. Combine the acquired results in the same sequence as provided in the ``map_*`` arguments.
            5. Return the combined results.

            *Steps 3-5 can be done lazily, on the fly with the help of an iterator*

        :param target:
            The ``Callable`` to be invoked inside a :py:class:`Process`.

            *It is invoked with the following signature:*

                ``target(map_iter[i], *map_args[i], *args, **map_kwargs[i], **kwargs)``

            *Where:*

                - ``i`` is the index of the n\ :sup:`th` element of the Iterable(s) provided in the ``map_*`` arguments.

                - ``args`` and ``kwargs`` are the same-named arguments to this function.

            The ``pass_state`` keyword argument additionally allows you to include the ``state`` argument.
        :param map_iter:
            A sequence whose elements are supplied as the *first* positional argument to the ``target``.
        :param map_args:
            A sequence whose elements are supplied as positional arguments (``*args``) to the ``target``.
        :param map_kwargs:
            A sequence whose elements are supplied as keyword arguments (``**kwargs``) to the ``target``.
        :param args:
            The argument tuple for ``target``, supplied after ``map_iter`` and ``map_args``.

            By default, it is an empty ``tuple``.
        :param kwargs:
            A dictionary of keyword arguments for ``target``.

            By default, it is an empty ``dict``.
        :param pass_state:
            Whether this process needs access to the state.

            If this is set to ``False``,
            then the ``state`` argument won't be provided to the ``target``.

            If this is set to ``True``,
            then a :py:class:`State` object is provided as the first argument to the ``target``.

            Unlike :py:class:`Process`, it is set to ``False`` by default
            (to retain an API similar to the built-in ``map()``).
        :param num_chunks:
            The number of worker nodes to use.

            By default, it is set to ``multiprocessing.cpu_count()``
            (The number of CPU cores on your system)
        :return:
            The result is quite similar to that of the built-in ``map()``.

            It returns an :py:class:`Iterable` which contains
            the return values of the ``target`` function,
            when applied to every item of the iterables provided in the ``map_*`` arguments.

            The actual "processing" starts as soon as you call this function.

            The returned :py:class:`Iterable` only fetches the results from the worker processes.

        .. note::
            - If ``len(map_iter) != len(map_args) != len(map_kwargs)``,
              then the results will be cut-off at the shortest Sequence.

        See :ref:`worker_map` for Examples.
        """
        if num_chunks is None:
            num_chunks = multiprocessing.cpu_count()

        lengths = [
            len(i) for i in (map_iter, map_args, map_kwargs) if i is not None
        ]
        assert (
            lengths
        ), "At least one of `map_iter`, `map_args`, or `map_kwargs` must be provided as a non-empty Sequence."

        length = min(lengths)

        assert (length > num_chunks
                ), "`length`(%d) must be greater than `num_chunks`(%d)" % (
                    length, num_chunks)

        # ceil-divide, so that num_chunks chunks cover all `length` items
        chunk_length, extra = divmod(length, num_chunks)
        if extra:
            chunk_length += 1
        task_id = util.generate_task_id((chunk_length, length, num_chunks))

        iter_chunks = util.make_chunks(map_iter, chunk_length, num_chunks)
        args_chunks = util.make_chunks(map_args, chunk_length, num_chunks)
        kwargs_chunks = util.make_chunks(map_kwargs, chunk_length, num_chunks)

        target_bytes = serializer.dumps_fn(target)

        for index in range(num_chunks):
            params = (
                iter_chunks[index],
                args_chunks[index],
                args,
                kwargs_chunks[index],
                kwargs,
            )
            task = (params, pass_state, self.namespace)

            self._task_push.send_multipart([
                util.encode_chunk_id(task_id, index),
                target_bytes,
                serializer.dumps(task),
            ])

        return SequenceTaskResult(self.server_address, task_id)
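
A hypothetical usage sketch (``workers`` is an illustrative name for an object exposing this ``map_lazy`` method, not a confirmed zproc API):

def square(x):
    return x * x

res = workers.map_lazy(square, range(100))
# the SequenceTaskResult streams results in lazily, in the original order
print(list(res))  # [0, 1, 4, 9, ...]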