async def test_exit_callback():
    """Check that the exit callback fires on normal exit and on terminate()."""
    inbox = mp_context.Queue()
    outbox = mp_context.Queue()
    stopped = Event()
    grace = timedelta(seconds=3)

    # FIXME: this breaks if changed to async def...
    @gen.coroutine
    def exit_hook(_proc):
        # The callback must be handed the very process it was registered on.
        assert _proc is child
        yield gen.moment
        stopped.set()

    # --- Scenario 1: the child exits on its own ---
    child = AsyncProcess(target=feed, args=(inbox, outbox))
    stopped.clear()
    child.set_exit_callback(exit_hook)
    child.daemon = True

    await child.start()
    await asyncio.sleep(0.05)
    # Still running, so the callback must not have fired yet.
    assert child.is_alive()
    assert not stopped.is_set()

    # presumably ``None`` is the sentinel that makes feed() return -- confirm
    inbox.put(None)
    await stopped.wait(grace)
    assert stopped.is_set()
    assert not child.is_alive()

    # --- Scenario 2: the child is terminated from the parent ---
    child = AsyncProcess(target=wait)
    stopped.clear()
    child.set_exit_callback(exit_hook)
    child.daemon = True

    await child.start()
    await asyncio.sleep(0.05)
    assert child.is_alive()
    assert not stopped.is_set()

    await child.terminate()
    await stopped.wait(grace)
    assert stopped.is_set()
def test_exit_callback():
    """Generator-style check that exit callbacks fire on exit and terminate().

    The body yields awaitables, so it is meant to be driven by a coroutine
    runner rather than called as a plain function.
    """
    inbox = mp_context.Queue()
    outbox = mp_context.Queue()
    stopped = Event()
    grace = timedelta(seconds=3)

    @gen.coroutine
    def exit_hook(_proc):
        # The callback must be handed the very process it was registered on.
        assert _proc is child
        yield gen.moment
        stopped.set()

    # --- Scenario 1: the child exits on its own ---
    child = AsyncProcess(target=feed, args=(inbox, outbox))
    stopped.clear()
    child.set_exit_callback(exit_hook)
    child.daemon = True

    yield child.start()
    yield gen.sleep(0.05)
    # Still running, so the callback must not have fired yet.
    assert child.is_alive()
    assert not stopped.is_set()

    # presumably ``None`` is the sentinel that makes feed() return -- confirm
    inbox.put(None)
    yield stopped.wait(grace)
    assert stopped.is_set()
    assert not child.is_alive()

    # --- Scenario 2: the child is terminated from the parent ---
    child = AsyncProcess(target=wait)
    stopped.clear()
    child.set_exit_callback(exit_hook)
    child.daemon = True

    yield child.start()
    yield gen.sleep(0.05)
    assert child.is_alive()
    assert not stopped.is_set()

    yield child.terminate()
    yield stopped.wait(grace)
    assert stopped.is_set()
Example #3
0
class WorkerProcess:
    """
    Parent-side handle for a worker that runs in a child ``AsyncProcess``.

    The handle spawns the child (``start``), waits for the child's init
    message on a multiprocessing queue, tracks the lifecycle in
    ``self.status`` and asks the child to shut down through a second queue
    (``kill``).  ``_run`` is the code executed inside the child.
    """

    # Set by start() once the child has reported its address
    running: asyncio.Event
    # Set by mark_stopped() once the child process has exited
    stopped: asyncio.Event

    # The interval how often to check the msg queue for init
    _init_msg_interval = 0.05

    def __init__(
        self,
        worker_kwargs,
        worker_start_args,
        silence_logs,
        on_exit,
        worker,
        env,
        config,
    ):
        """
        Store the configuration; no process is created until ``start()``.

        Parameters
        ----------
        worker_kwargs
            Keyword arguments used in the child as ``Worker(**worker_kwargs)``.
        worker_start_args
            Forwarded to ``_run``; the visible ``_run`` body does not use it.
        silence_logs
            If truthy, passed to ``logger.setLevel`` in the child.
        on_exit
            Optional callable invoked with the exit code by ``mark_stopped``.
        worker
            Callable (stored as ``self.Worker``) that builds the worker.
        env
            Mapping merged into ``os.environ`` in the child.
        config
            Passed to ``dask.config.set`` in the child.
        """
        self.status = Status.init
        self.silence_logs = silence_logs
        self.worker_kwargs = worker_kwargs
        self.worker_start_args = worker_start_args
        self.on_exit = on_exit
        self.process = None
        self.Worker = worker
        self.env = env
        self.config = config

        # Initialized when worker is ready
        self.worker_dir = None
        self.worker_address = None

    async def start(self) -> Status:
        """
        Ensure the worker process is started.

        Returns ``self.status``.  If the worker is already running this
        returns immediately; if another start is in progress it waits on
        ``self.running`` first.  Otherwise a child process is spawned and
        this waits for its init message.

        ``Status.failed`` is returned when spawning raises ``OSError``.
        An exception raised while waiting for the init message marks the
        status as failed, requests termination of the child and is re-raised.
        """
        enable_proctitle_on_children()
        if self.status == Status.running:
            return self.status
        if self.status == Status.starting:
            await self.running.wait()
            return self.status

        # Fresh queues per start: one for the child's init result, one to
        # send the stop request to the child.
        self.init_result_q = init_q = mp_context.Queue()
        self.child_stop_q = mp_context.Queue()
        # Unique id echoed back by the child so that stale messages can be
        # told apart (see _wait_until_connected).
        uid = uuid.uuid4().hex

        self.process = AsyncProcess(
            target=self._run,
            name="Dask Worker process (from Nanny)",
            kwargs=dict(
                worker_kwargs=self.worker_kwargs,
                worker_start_args=self.worker_start_args,
                silence_logs=self.silence_logs,
                init_result_q=self.init_result_q,
                child_stop_q=self.child_stop_q,
                uid=uid,
                Worker=self.Worker,
                env=self.env,
                config=self.config,
            ),
        )
        self.process.daemon = dask.config.get("distributed.worker.daemon",
                                              default=True)
        self.process.set_exit_callback(self._on_exit)
        self.running = asyncio.Event()
        self.stopped = asyncio.Event()
        self.status = Status.starting

        try:
            await self.process.start()
        except OSError:
            logger.exception("Nanny failed to start process", exc_info=True)
            # NOTE(review): the result of terminate() is not awaited here,
            # unlike in kill() -- confirm this is intentional.
            self.process.terminate()
            self.status = Status.failed
            return self.status
        try:
            msg = await self._wait_until_connected(uid)
        except Exception:
            self.status = Status.failed
            self.process.terminate()
            raise
        if not msg:
            # The status left ``starting`` while we were waiting, so there
            # is no init message to process.
            return self.status
        self.worker_address = msg["address"]
        self.worker_dir = msg["dir"]
        assert self.worker_address
        self.status = Status.running
        self.running.set()

        # The init queue is only needed for the single init message.
        init_q.close()

        return self.status

    def _on_exit(self, proc) -> None:
        """Exit callback registered on the process; marks this handle stopped."""
        if proc is not self.process:
            # Ignore exit of old process instance
            return
        self.mark_stopped()

    def _death_message(self, pid, exitcode) -> str:
        """
        Build a human-readable description of how process *pid* ended.

        ``exitcode`` 255 is reported as an unknown signal, other
        non-negative codes as an exit status and negative codes as the
        signal number ``-exitcode``.
        """
        assert exitcode is not None
        if exitcode == 255:
            return "Worker process %d was killed by unknown signal" % (pid, )
        elif exitcode >= 0:
            return "Worker process %d exited with status %d" % (pid, exitcode)
        else:
            return "Worker process %d was killed by signal %d" % (pid,
                                                                  -exitcode)

    def is_alive(self) -> bool:
        """Return True if a process exists and reports itself alive."""
        return self.process is not None and self.process.is_alive()

    @property
    def pid(self):
        """PID of the child while it is alive, otherwise ``None``."""
        return self.process.pid if self.process and self.process.is_alive(
        ) else None

    def mark_stopped(self) -> None:
        """
        Record that the child has exited and release its resources.

        Does nothing if the status is already ``Status.stopped``.  Otherwise
        it logs a death message for non-zero exit codes, sets
        ``self.stopped``, closes the process, drops the queues, removes the
        worker directory (best effort) and finally calls ``self.on_exit``
        with the exit code, if a hook was given.
        """
        if self.status != Status.stopped:
            r = self.process.exitcode
            assert r is not None
            if r != 0:
                msg = self._death_message(self.process.pid, r)
                logger.info(msg)
            self.status = Status.stopped
            self.stopped.set()
            # Release resources
            self.process.close()
            self.init_result_q = None
            self.child_stop_q = None
            self.process = None
            # Best effort to clean up worker directory
            if self.worker_dir and os.path.exists(self.worker_dir):
                shutil.rmtree(self.worker_dir, ignore_errors=True)
            self.worker_dir = None
            # User hook
            if self.on_exit is not None:
                self.on_exit(r)

    async def kill(self, timeout: float = 2, executor_wait: bool = True) -> None:
        """
        Ensure the worker process is stopped, waiting at most
        *timeout* seconds before terminating it abruptly.

        A ``{"op": "stop", ...}`` message carrying 80% of the remaining
        time and *executor_wait* is sent to the child; the child is then
        polled every 0.05 s until it exits or the deadline passes.  If it is
        still alive afterwards it is terminated; a failure to terminate is
        logged, not raised.

        Returns immediately if already stopped, and waits on
        ``self.stopped`` if another stop is already in progress.  Any status
        other than starting/running/stopping/stopped trips the assertion.
        """
        deadline = time() + timeout

        if self.status == Status.stopped:
            return
        if self.status == Status.stopping:
            await self.stopped.wait()
            return
        assert self.status in (Status.starting, Status.running)
        self.status = Status.stopping

        # Keep a local reference: mark_stopped() resets self.process to None
        # when the child exits.
        process = self.process
        self.child_stop_q.put({
            "op": "stop",
            "timeout": max(0, deadline - time()) * 0.8,
            "executor_wait": executor_wait,
        })
        await asyncio.sleep(0)  # otherwise we get broken pipe errors
        self.child_stop_q.close()

        while process.is_alive() and time() < deadline:
            await asyncio.sleep(0.05)

        if process.is_alive():
            logger.warning(
                f"Worker process still alive after {timeout} seconds, killing")
            try:
                await process.terminate()
            except Exception as e:
                logger.error("Failed to kill worker process: %s", e)

    async def _wait_until_connected(self, uid):
        """
        Poll the init queue until the child identified by *uid* reports in.

        Returns the child's message dict on success, or ``None`` as soon as
        ``self.status`` is no longer ``Status.starting``.  If the message
        carries an ``"exception"`` entry, it is logged and raised.
        """
        while True:
            if self.status != Status.starting:
                return
            # This is a multiprocessing queue and we'd block the event loop if
            # we simply called get
            try:
                msg = self.init_result_q.get_nowait()
            except Empty:
                await asyncio.sleep(self._init_msg_interval)
                continue

            if msg["uid"] != uid:  # ensure that we didn't cross queues
                continue

            if "exception" in msg:
                logger.error("Failed while trying to start worker process: %s",
                             msg["exception"])
                raise msg["exception"]
            else:
                return msg

    @classmethod
    def _run(
        cls,
        worker_kwargs,
        worker_start_args,
        silence_logs,
        init_result_q,
        child_stop_q,
        uid,
        env,
        config,
        Worker,
    ):  # pragma: no cover
        """
        Entry point executed in the child process (the ``target`` that
        ``start()`` hands to ``AsyncProcess``).

        Applies *env* and *config*, creates a fresh ``IOLoop`` and the
        worker, starts a daemon thread that waits for the stop message, and
        then runs the worker on the loop.  The outcome of the start-up
        (address and directory, or the exception) is reported to the parent
        through *init_result_q*, tagged with *uid*.
        """
        try:
            os.environ.update(env)
            dask.config.set(config)
            try:
                from dask.multiprocessing import initialize_worker_process
            except ImportError:  # old Dask version
                pass
            else:
                initialize_worker_process()

            if silence_logs:
                logger.setLevel(silence_logs)

            # Replace any existing IOLoop singleton with a fresh loop that
            # is made current for this process.
            IOLoop.clear_instance()
            loop = IOLoop()
            loop.make_current()
            worker = Worker(**worker_kwargs)

            async def do_stop(timeout=5, executor_wait=True):
                # Close the worker, then stop the loop even if closing fails.
                try:
                    await worker.close(
                        report=True,
                        nanny=False,
                        safe=True,  # TODO: Graceful or not?
                        executor_wait=executor_wait,
                        timeout=timeout,
                    )
                finally:
                    loop.stop()

            def watch_stop_q():
                """
                Wait for an incoming stop message and then stop the
                worker cleanly.
                """
                # Blocking get: this function runs in its own thread (below).
                msg = child_stop_q.get()
                child_stop_q.close()
                assert msg.pop("op") == "stop"
                # The remaining keys of the message become do_stop's kwargs.
                loop.add_callback(do_stop, **msg)

            t = threading.Thread(target=watch_stop_q,
                                 name="Nanny stop queue watch")
            t.daemon = True
            t.start()

            async def run():
                """
                Try to start worker and inform parent of outcome.
                """
                try:
                    await worker
                except Exception as e:
                    logger.exception("Failed to start worker")
                    init_result_q.put({"uid": uid, "exception": e})
                    init_result_q.close()
                    # If we hit an exception here we need to wait for at least
                    # one interval for the outside to pick up this message.
                    # Otherwise we arrive in a race condition where the process
                    # cleanup wipes the queue before the exception can be
                    # properly handled. See also
                    # WorkerProcess._wait_until_connected (the 2 is for good
                    # measure)
                    sync_sleep(cls._init_msg_interval * 2)
                else:
                    # NOTE(review): a failing assert raises AssertionError,
                    # so this handler presumably guards a ValueError raised
                    # by the ``address`` attribute itself -- confirm.
                    try:
                        assert worker.address
                    except ValueError:
                        pass
                    else:
                        init_result_q.put({
                            "address": worker.address,
                            "dir": worker.local_directory,
                            "uid": uid,
                        })
                        init_result_q.close()
                        await worker.finished()
                        logger.info("Worker closed")

        except Exception as e:
            logger.exception("Failed to initialize Worker")
            init_result_q.put({"uid": uid, "exception": e})
            init_result_q.close()
            # If we hit an exception here we need to wait for at least one
            # interval for the outside to pick up this message. Otherwise we
            # arrive in a race condition where the process cleanup wipes the
            # queue before the exception can be properly handled. See also
            # WorkerProcess._wait_until_connected (the 2 is for good measure)
            sync_sleep(cls._init_msg_interval * 2)
        else:
            # Setup succeeded: run the worker until the loop is stopped.
            try:
                loop.run_sync(run)
            except (TimeoutError, gen.TimeoutError):
                # Loop was stopped before wait_until_closed() returned, ignore
                pass
            except KeyboardInterrupt:
                # At this point the loop is not running thus we have to run
                # do_stop() explicitly.
                loop.run_sync(do_stop)
Example #4
0
class Scheduler(ProcessInterface):
    """
    Run a scheduler (``_Scheduler``) inside a child ``AsyncProcess``.

    The parent talks to the child over three multiprocessing queues:

    * ``init_result_q``     -- child -> parent, outcome of the start-up
    * ``child_info_stop_q`` -- parent -> child, ``"info"``/``"stop"`` requests
    * ``parent_info_q``     -- child -> parent, replies to ``"info"`` requests
    """

    # Seconds ``__repr__`` waits for the child's "info" reply.
    # NOTE(review): 3000 s blocks for 50 minutes in the worst case; this
    # looks like it was meant as milliseconds -- confirm before lowering.
    _info_timeout = 3000

    def __init__(self, env=None, *args, **kwargs):
        """
        Parameters
        ----------
        env : mapping, optional
            Merged into ``os.environ`` in the child process.
        *args
            Stored on ``self.args``; not used elsewhere in this class.
        **kwargs
            Passed to the scheduler class in the child
            (``proc_cls(**kwargs)``).
        """
        super().__init__()

        self.args = args
        self.kwargs = kwargs
        self.proc_cls = _Scheduler
        self.process = None
        self.env = env or {}

    def __repr__(self):
        """
        Describe the scheduler, asking the child for live numbers.

        ``__repr__`` must always return a string, so when the child cannot
        be queried (never started, queues already released or closed, no
        reply within ``_info_timeout``) a status-only representation is
        returned instead of ``None`` or an exception.
        """
        fallback = "<Scheduler: status=%s>" % (self.status, )

        # The queues only exist after start() and are reset to None by
        # mark_stopped().
        request_q = getattr(self, "child_info_stop_q", None)
        reply_q = getattr(self, "parent_info_q", None)
        if request_q is None or reply_q is None:
            return fallback

        try:
            request_q.put({"op": "info"})
            msg = reply_q.get(timeout=self._info_timeout)
        except (Empty, ValueError, OSError):
            # Empty: no reply in time.  ValueError/OSError: the queue was
            # already closed (e.g. by close()).
            return fallback

        assert msg.pop("op") == "info"
        return "<Scheduler: '%s' processes: %d cores: %d>" % (
            self.address,
            msg.pop("workers"),
            msg.pop("total_nthreads"),
        )

    async def _wait_until_started(self):
        """
        Poll the init queue until the child reports the start-up outcome.

        Returns the child's message dict on success, or ``None`` as soon as
        ``self.status`` is no longer ``"starting"``.

        Raises
        ------
        Exception
            The exception the child sent in the ``"exception"`` entry of its
            message, after the child process has been joined.
        """
        delay = 0.05
        while True:
            if self.status != "starting":
                return
            # get_nowait + sleep keeps the event loop responsive; a blocking
            # get on a multiprocessing queue would stall it.
            try:
                msg = self.init_result_q.get_nowait()
            except Empty:
                await gen.sleep(delay)
                continue

            if "exception" in msg:
                logger.error(
                    "Failed while trying to start scheduler process: %s",
                    msg["exception"],
                )
                await self.process.join()
                # Raise the transported exception itself; ``msg`` is a dict
                # and raising it would fail with TypeError.
                raise msg["exception"]
            else:
                return msg

    async def start(self):
        """
        Ensure the scheduler process is started.

        Returns ``self.status`` when already running, when another start is
        in progress (after waiting for it), or when the status changed while
        waiting for the child.  Returns ``None`` after a successful start
        and when spawning the process raises ``OSError``.
        """
        if self.status == "running":
            return self.status
        if self.status == "starting":
            await self.running.wait()
            return self.status

        self.init_result_q = init_q = mp_context.Queue()
        self.child_info_stop_q = mp_context.Queue()
        self.parent_info_q = mp_context.Queue()

        self.process = AsyncProcess(
            target=self._run,
            name="Dask CUDA Scheduler process",
            kwargs=dict(
                proc_cls=self.proc_cls,
                kwargs=self.kwargs,
                silence_logs=False,
                init_result_q=self.init_result_q,
                child_info_stop_q=self.child_info_stop_q,
                parent_info_q=self.parent_info_q,
                env=self.env,
            ),
        )
        # self.process.daemon = dask.config.get("distributed.worker.daemon", default=True)
        self.process.set_exit_callback(self._on_exit)
        self.running = Event()
        self.stopped = Event()
        self.status = "starting"
        try:
            await self.process.start()
        except OSError:
            logger.exception("Failed to start CUDA Scheduler process",
                             exc_info=True)
            # NOTE(review): the status stays "starting" on this path, so a
            # later start() would wait on ``self.running`` -- confirm whether
            # a failure status should be recorded here.
            self.process.terminate()
            return

        msg = await self._wait_until_started()
        if not msg:
            return self.status
        self.address = msg["address"]
        assert self.address
        self.status = "running"
        self.running.set()

        # The init queue is only needed for the single init message.
        init_q.close()

        await super().start()

    def _on_exit(self, proc):
        """Exit callback registered on the process; marks this handle stopped."""
        if proc is not self.process:
            # Exit notification of an older process instance: ignore.
            return
        self.mark_stopped()

    def _death_message(self, pid, exitcode):
        """
        Build a human-readable description of how process *pid* ended.

        ``exitcode`` 255 is reported as an unknown signal, other
        non-negative codes as an exit status and negative codes as the
        signal number ``-exitcode``.
        """
        assert exitcode is not None
        if exitcode == 255:
            return "Scheduler process %d was killed by unknown signal" % (
                pid, )
        elif exitcode >= 0:
            return "Scheduler process %d exited with status %d" % (pid,
                                                                   exitcode)
        else:
            return "Scheduler process %d was killed by signal %d" % (pid,
                                                                     -exitcode)

    def mark_stopped(self):
        """
        Record that the child has exited and release its resources.

        Does nothing if the status is already ``"stopped"``.  Otherwise it
        logs a death message for non-zero exit codes, sets ``self.stopped``,
        closes the process and drops the process and queue references.
        """
        if self.status != "stopped":
            r = self.process.exitcode
            assert r is not None
            if r != 0:
                msg = self._death_message(self.process.pid, r)
                logger.info(msg)
            self.status = "stopped"
            self.stopped.set()
            # Release resources
            self.process.close()
            self.init_result_q = None
            self.child_info_stop_q = None
            self.parent_info_q = None
            self.process = None

    async def close(self):
        """
        Stop the scheduler process, giving it up to 2 seconds to exit.

        A ``"stop"`` request carrying 80% of the remaining time is sent to
        the child, which is then polled every 0.05 s.  If it is still alive
        at the deadline it is terminated.  Shutdown is best effort: errors
        are logged rather than raised, and the status always ends up
        ``"closed"``.
        """
        timeout = 2
        loop = IOLoop.current()
        deadline = loop.time() + timeout
        if self.status == "closing":
            await self.finished()
            assert self.status == "closed"

        if self.status == "closed":
            return

        try:
            if self.process is not None:
                # await self.kill()
                # Keep a local reference: mark_stopped() resets self.process
                # to None when the child exits.
                process = self.process
                self.child_info_stop_q.put({
                    "op": "stop",
                    "timeout": max(0, deadline - loop.time()) * 0.8,
                })
                self.child_info_stop_q.close()
                self.parent_info_q.close()

                while process.is_alive() and loop.time() < deadline:
                    await gen.sleep(0.05)

                if process.is_alive():
                    logger.warning(
                        "Scheduler process still alive after %d seconds, killing",
                        timeout)
                    try:
                        await process.terminate()
                    except Exception as e:
                        logger.error("Failed to kill scheduler process: %s", e)
        except Exception:
            # Best effort: never let a shutdown problem escape close(), but
            # leave a trace instead of swallowing it silently.
            logger.warning("Error while stopping scheduler process",
                           exc_info=True)
        self.process = None
        self.status = "closed"
        await super().close()

    @classmethod
    def _run(
        cls,
        silence_logs,
        init_result_q,
        child_info_stop_q,
        parent_info_q,
        proc_cls,
        kwargs,
        env,
    ):  # pragma: no cover
        """
        Entry point executed in the child process (the ``target`` that
        ``start()`` hands to ``AsyncProcess``).

        Applies *env*, creates a fresh ``IOLoop`` and the scheduler
        (``proc_cls(**kwargs)``), starts a daemon thread that serves
        ``"info"``/``"stop"`` requests, and runs the scheduler on the loop.
        The start-up outcome (address or exception) is reported to the
        parent through *init_result_q*.
        """
        os.environ.update(env)

        if silence_logs:
            logger.setLevel(silence_logs)

        # Replace any existing IOLoop singleton with a fresh loop that is
        # made current for this process.
        IOLoop.clear_instance()
        loop = IOLoop()
        loop.make_current()
        scheduler = proc_cls(**kwargs)

        async def do_stop(timeout=5):
            # ``timeout`` is accepted because the parent's stop message
            # carries it as a keyword; it is not forwarded to close().
            try:
                await scheduler.close(comm=None,
                                      fast=False,
                                      close_workers=False)
            finally:
                # Stop the loop even if closing the scheduler failed.
                loop.stop()

        def watch_stop_q():
            """
            Serve requests from the parent: answer ``"info"`` requests and,
            on ``"stop"``, schedule a clean shutdown and end the thread.
            """
            while True:
                try:
                    msg = child_info_stop_q.get(timeout=1000)
                except Empty:
                    # Nothing arrived within the timeout: keep waiting.
                    pass
                else:
                    op = msg.pop("op")
                    assert op == "stop" or op == "info"
                    if op == "stop":
                        child_info_stop_q.close()
                        # The remaining keys of the message become do_stop's
                        # keyword arguments.
                        loop.add_callback(do_stop, **msg)
                        break
                    elif op == "info":
                        parent_info_q.put({
                            "op": "info",
                            "workers": len(scheduler.workers),
                            "total_nthreads": scheduler.total_nthreads,
                        })

        t = threading.Thread(target=watch_stop_q,
                             name="Scheduler stop queue watch")
        t.daemon = True
        t.start()

        async def run():
            """
            Try to start scheduler and inform parent of outcome.
            """
            try:
                await scheduler.start()
            except Exception as e:
                logger.exception("Failed to start scheduler")
                init_result_q.put({"exception": e})
                init_result_q.close()
            else:
                # NOTE(review): a failing assert raises AssertionError, so
                # this handler presumably guards a ValueError raised by the
                # ``address`` attribute itself -- confirm.
                try:
                    assert scheduler.address
                except ValueError:
                    pass
                else:
                    init_result_q.put({"address": scheduler.address})
                    init_result_q.close()
                    await scheduler.finished()
                    logger.info("Scheduler closed")

        try:
            loop.run_sync(run)
        except (TimeoutError, gen.TimeoutError):
            # Loop was stopped before wait_until_closed() returned, ignore
            pass
        except KeyboardInterrupt:
            pass