Example #1
def test_basic(loop, worker_class, mpirun):
    try:
        import_term(worker_class)
    except (ImportError, AttributeError):
        pytest.skip("Cannot import {}, perhaps it is not installed".format(
            worker_class))
    with tmpfile(extension="json") as fn:

        cmd = mpirun + [
            "-np",
            "4",
            "dask-mpi",
            "--scheduler-file",
            fn,
            "--worker-class",
            worker_class,
        ]

        with popen(cmd):
            with Client(scheduler_file=fn) as c:
                start = time()
                while len(c.scheduler_info()["workers"]) < 3:
                    assert time() < start + 10
                    sleep(0.2)

                assert c.submit(lambda x: x + 1, 10).result() == 11
Example #2
async def run_spec(spec: dict, *args):
    workers = {}
    for k, d in spec.items():
        cls = d["cls"]
        if isinstance(cls, str):
            cls = import_term(cls)
        workers[k] = cls(*args, **d.get("opts", {}))

    if workers:
        await asyncio.gather(*workers.values())
        for w in workers.values():
            await w  # for tornado gen.coroutine support
    return workers
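
A minimal sketch of the spec dictionary that run_spec consumes; the class paths, option values, and the scheduler_address variable are illustrative assumptions, not taken from the original project:

spec = {
    "worker-0": {"cls": "distributed.Worker", "opts": {"nthreads": 2}},
    "worker-1": {"cls": "distributed.Nanny", "opts": {"nthreads": 2}},
}
# Each "cls" string is resolved with import_term before instantiation, so the
# spec itself stays plain data (strings and dicts) and remains serializable.
# workers = await run_spec(spec, scheduler_address)
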
Example #3
    async def _start(self):
        async with self._lock:
            while self.status == "starting":
                await asyncio.sleep(0.01)
            if self.status == "running":
                return
            if self.status == "closed":
                raise ValueError("Cluster is closed")
            self.status = "starting"

        if self.local:
            if self.scheduler_spec is None:
                raise ValueError("No scheduler was set")

            scheduler_clz = self.scheduler_spec["cls"]
            if isinstance(scheduler_clz, str):
                scheduler_clz = import_term(scheduler_clz)
            self.scheduler = scheduler_clz(
                **self.scheduler_spec.get("options", {}))
            self.scheduler = await self.scheduler
            self.scheduler_comm = rpc(
                getattr(self.scheduler, "external_address", None)
                or self.scheduler.address,
                connection_args=self.security.get_connection_args("client"),
            )
            with open(self.scheduler_file, 'w+') as f:
                json.dump({'address': self.scheduler.address}, f)
            await self._spawn_workers()
        else:
            await self._spawn_workers()
            counter = 1
            while not os.path.exists(self.scheduler_file):
                await asyncio.sleep(5.0)
                counter += 5
                if self.worker.returncode is not None or counter > int(
                        self.timeout):
                    _, err = await self.worker.communicate()
                    self.worker.terminate()
                    raise ValueError(
                        "Scheduler failed to spawn. The request for a job returned: "
                        + err.decode())
            with open(self.scheduler_file, "r") as scheduler_file:
                obj = json.load(scheduler_file)
                self.scheduler_comm = rpc(obj["address"])

        await super()._start()
Example #4
File: core.py  Project: dask/dask-mpi
async def run_worker():
    WorkerType = import_term(worker_class)
    if nanny:
        raise DeprecationWarning(
            "Option nanny=True is deprecated, use worker_class='distributed.Nanny' instead"
        )
    opts = {
        "interface": interface,
        "protocol": protocol,
        "nthreads": nthreads,
        "memory_limit": memory_limit,
        "local_directory": local_directory,
        "name": rank,
        **worker_options,
    }
    async with WorkerType(**opts) as worker:
        await worker.finished()
Example #5
    async def _correct_state_internal(self):
        async with self._lock:
            self._correct_state_waiting = None

            to_close = set(self.workers) - set(self.worker_spec)
            if to_close:
                if self.scheduler.status == Status.running:
                    await self.scheduler_comm.retire_workers(
                        workers=list(to_close))
                tasks = [
                    asyncio.create_task(self.workers[w].close())
                    for w in to_close if w in self.workers
                ]
                await asyncio.wait(tasks)
                for task in tasks:  # for tornado gen.coroutine support
                    with suppress(RuntimeError):
                        await task
            for name in to_close:
                if name in self.workers:
                    del self.workers[name]

            to_open = set(self.worker_spec) - set(self.workers)
            workers = []
            for name in to_open:
                d = self.worker_spec[name]
                cls, opts = d["cls"], d.get("options", {})
                if "name" not in opts:
                    opts = opts.copy()
                    opts["name"] = name
                if isinstance(cls, str):
                    cls = import_term(cls)
                worker = cls(self.scheduler.address, **opts)
                self._created.add(worker)
                workers.append(worker)
            if workers:
                await asyncio.wait(workers)
                for w in workers:
                    w._cluster = weakref.ref(self)
                    await w  # for tornado gen.coroutine support
            self.workers.update(dict(zip(to_open, workers)))
Example #6
    def __init__(
        self,
        scheduler: Scheduler,
        # The following parameters are exposed so that one may create, run, and
        # throw away a specialized manager on the fly, separate from the main one.
        policies: set[ActiveMemoryManagerPolicy] | None = None,
        *,
        register: bool = True,
        start: bool | None = None,
        interval: float | None = None,
    ):
        self.scheduler = scheduler
        self.policies = set()

        if policies is None:
            # Initialize policies from config
            policies = set()
            for kwargs in dask.config.get(
                    "distributed.scheduler.active-memory-manager.policies"):
                kwargs = kwargs.copy()
                cls = import_term(kwargs.pop("class"))
                policies.add(cls(**kwargs))

        for policy in policies:
            self.add_policy(policy)

        if register:
            scheduler.extensions["amm"] = self
            scheduler.handlers["amm_handler"] = self.amm_handler

        if interval is None:
            interval = parse_timedelta(
                dask.config.get(
                    "distributed.scheduler.active-memory-manager.interval"))
        self.interval = interval
        if start is None:
            start = dask.config.get(
                "distributed.scheduler.active-memory-manager.start")
        if start:
            self.start()
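
The config-driven branch above pops the "class" key from each policy entry, resolves it with import_term, and passes the remaining keys as keyword arguments. A minimal sketch of setting that config from Python, assuming distributed's built-in ReduceReplicas policy as the class path:

import dask

dask.config.set(
    {
        "distributed.scheduler.active-memory-manager.policies": [
            # "class" is popped and resolved with import_term; any other keys
            # in the entry would become keyword arguments for the policy.
            {"class": "distributed.active_memory_manager.ReduceReplicas"},
        ]
    }
)
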
Example #7
        async def run_worker():
            WorkerType = import_term(worker_class)
            if not nanny:
                raise DeprecationWarning(
                    "Option --no-nanny is deprecated, use --worker-class instead"
                )
            opts = {
                "interface": interface,
                "protocol": protocol,
                "nthreads": nthreads,
                "memory_limit": memory_limit,
                "local_directory": local_directory,
                "name": f"{name}-{rank}",
                "scheduler_file": scheduler_file,
                **worker_options,
            }
            if scheduler_address:
                opts["scheduler_ip"] = scheduler_address
            async with WorkerType(**opts) as worker:
                await worker.finished()
Example #8
    async def _start(self):
        while self.status == Status.starting:
            await asyncio.sleep(0.01)
        if self.status == Status.running:
            return
        if self.status == Status.closed:
            raise ValueError("Cluster is closed")

        self._lock = asyncio.Lock()
        self.status = Status.starting

        if self.scheduler_spec is None:
            options = {}
            try:
                import distributed.dashboard  # noqa: F401
            except ImportError:
                pass
            else:
                options = {"dashboard": True}
            self.scheduler_spec = {"cls": Scheduler, "options": options}

        # Check if scheduler has already been created by a subclass
        if self.scheduler is None:
            cls = self.scheduler_spec["cls"]
            if isinstance(cls, str):
                cls = import_term(cls)
            self.scheduler = cls(**self.scheduler_spec.get("options", {}))
            self.scheduler = await self.scheduler
        self.scheduler_comm = rpc(
            getattr(self.scheduler, "external_address", None)
            or self.scheduler.address,
            connection_args=self.security.get_connection_args("client"),
        )
        try:
            await super()._start()
        except Exception as e:  # pragma: no cover
            self.status = Status.failed
            await self._close()
            raise RuntimeError(f"Cluster failed to start: {e}") from e
Example #9
    def __call__(self, *args, **kwargs):
        from distributed.utils import import_term

        return import_term(self.function_path)(*args, **kwargs)
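
For context, import_term essentially imports the module portion of a dotted path and returns the named attribute, which is what lets a callable be stored as a plain string like function_path above. A minimal sketch:

from distributed.utils import import_term

# "math.sqrt" -> import the "math" module, then getattr(math, "sqrt")
sqrt = import_term("math.sqrt")
assert sqrt(16.0) == 4.0
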
Example #10
def main(scheduler, host, worker_port, listen_address, contact_address,
         nanny_port, nthreads, nprocs, nanny, name, pid_file, resources,
         dashboard, bokeh, bokeh_port, scheduler_file, dashboard_prefix,
         tls_ca_file, tls_cert, tls_key, dashboard_address, worker_class,
         preload_nanny, **kwargs):
    g0, g1, g2 = gc.get_threshold()  # https://github.com/dask/distributed/issues/1653
    gc.set_threshold(g0 * 3, g1 * 3, g2 * 3)

    enable_proctitle_on_current()
    enable_proctitle_on_children()

    if bokeh_port is not None:
        warnings.warn(
            "The --bokeh-port flag has been renamed to --dashboard-address. "
            "Consider adding ``--dashboard-address :%d`` " % bokeh_port)
        dashboard_address = bokeh_port
    if bokeh is not None:
        warnings.warn(
            "The --bokeh/--no-bokeh flag has been renamed to --dashboard/--no-dashboard. "
        )
        dashboard = bokeh

    sec = {
        k: v
        for k, v in [
            ("tls_ca_file", tls_ca_file),
            ("tls_worker_cert", tls_cert),
            ("tls_worker_key", tls_key),
        ] if v is not None
    }

    if nprocs < 0:
        nprocs = CPU_COUNT + 1 + nprocs

    if nprocs <= 0:
        logger.error(
            "Failed to launch worker. Must specify --nprocs so that there's at least one process."
        )
        sys.exit(1)

    if nprocs > 1 and not nanny:
        logger.error(
            "Failed to launch worker.  You cannot use the --no-nanny argument when nprocs > 1."
        )
        sys.exit(1)

    if contact_address and not listen_address:
        logger.error(
            "Failed to launch worker. "
            "Must specify --listen-address when --contact-address is given")
        sys.exit(1)

    if nprocs > 1 and listen_address:
        logger.error("Failed to launch worker. "
                     "You cannot specify --listen-address when nprocs > 1.")
        sys.exit(1)

    if (worker_port or host) and listen_address:
        logger.error(
            "Failed to launch worker. "
            "You cannot specify --listen-address when --worker-port or --host is given."
        )
        sys.exit(1)

    try:
        if listen_address:
            (host, worker_port) = get_address_host_port(listen_address,
                                                        strict=True)

        if contact_address:
            # we only need this to verify it is getting parsed
            (_, _) = get_address_host_port(contact_address, strict=True)
        else:
            # if contact address is not present we use the listen_address for contact
            contact_address = listen_address
    except ValueError as e:
        logger.error("Failed to launch worker. " + str(e))
        sys.exit(1)

    if nanny:
        port = nanny_port
    else:
        port = worker_port

    if not nthreads:
        nthreads = CPU_COUNT // nprocs

    if pid_file:
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    if resources:
        resources = resources.replace(",", " ").split()
        resources = dict(pair.split("=") for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    worker_class = import_term(worker_class)
    if nanny:
        kwargs["worker_class"] = worker_class
        kwargs["preload_nanny"] = preload_nanny

    if nanny:
        kwargs.update({
            "worker_port": worker_port,
            "listen_address": listen_address
        })
        t = Nanny
    else:
        if nanny_port:
            kwargs["service_ports"] = {"nanny": nanny_port}
        t = worker_class

    if (not scheduler and not scheduler_file
            and dask.config.get("scheduler-address", None) is None):
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    with suppress(TypeError, ValueError):
        name = int(name)

    if "DASK_INTERNAL_INHERIT_CONFIG" in os.environ:
        config = deserialize_for_cli(
            os.environ["DASK_INTERNAL_INHERIT_CONFIG"])
        # Update the global config given priority to the existing global config
        dask.config.update(dask.config.global_config, config, priority="old")

    nannies = [
        t(scheduler,
          scheduler_file=scheduler_file,
          nthreads=nthreads,
          loop=loop,
          resources=resources,
          security=sec,
          contact_address=contact_address,
          host=host,
          port=port,
          dashboard=dashboard,
          dashboard_address=dashboard_address,
          name=name if nprocs == 1 or name is None or name == "" else
          str(name) + "-" + str(i),
          **kwargs) for i in range(nprocs)
    ]

    async def close_all():
        # Unregister all workers from scheduler
        if nanny:
            await asyncio.gather(*[n.close(timeout=2) for n in nannies])

    signal_fired = False

    def on_signal(signum):
        nonlocal signal_fired
        signal_fired = True
        if signum != signal.SIGINT:
            logger.info("Exiting on signal %d", signum)
        return asyncio.ensure_future(close_all())

    async def run():
        await asyncio.gather(*nannies)
        await asyncio.gather(*[n.finished() for n in nannies])

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except TimeoutError:
        # We already log the exception in nanny / worker. Don't do it again.
        if not signal_fired:
            logger.info("Timed out starting worker")
        sys.exit(1)
    except KeyboardInterrupt:
        pass
    finally:
        logger.info("End worker")
Example #11
def main(
    scheduler,
    host,
    nthreads,
    name,
    memory_limit,
    device_memory_limit,
    rmm_pool_size,
    rmm_maximum_pool_size,
    rmm_managed_memory,
    rmm_async,
    rmm_log_directory,
    pid_file,
    resources,
    dashboard,
    dashboard_address,
    local_directory,
    shared_filesystem,
    scheduler_file,
    interface,
    preload,
    dashboard_prefix,
    tls_ca_file,
    tls_cert,
    tls_key,
    enable_tcp_over_ucx,
    enable_infiniband,
    enable_nvlink,
    enable_rdmacm,
    net_devices,
    enable_jit_unspill,
    worker_class,
    **kwargs,
):
    if tls_ca_file and tls_cert and tls_key:
        security = Security(
            tls_ca_file=tls_ca_file,
            tls_worker_cert=tls_cert,
            tls_worker_key=tls_key,
        )
    else:
        security = None

    if isinstance(scheduler, str) and scheduler.startswith("-"):
        raise ValueError(
            "The scheduler address can't start with '-'. Please check "
            "your command line arguments, you probably attempted to use "
            "an unsupported one. Scheduler address: %s" % scheduler)

    if worker_class is not None:
        worker_class = import_term(worker_class)

    worker = CUDAWorker(
        scheduler,
        host,
        nthreads,
        name,
        memory_limit,
        device_memory_limit,
        rmm_pool_size,
        rmm_maximum_pool_size,
        rmm_managed_memory,
        rmm_async,
        rmm_log_directory,
        pid_file,
        resources,
        dashboard,
        dashboard_address,
        local_directory,
        shared_filesystem,
        scheduler_file,
        interface,
        preload,
        dashboard_prefix,
        security,
        enable_tcp_over_ucx,
        enable_infiniband,
        enable_nvlink,
        enable_rdmacm,
        net_devices,
        enable_jit_unspill,
        worker_class,
        **kwargs,
    )

    async def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        await worker.close()

    async def run():
        await worker
        await worker.finished()

    loop = IOLoop.current()

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")