def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None:
        """Pause or resume ``worker`` according to its process memory usage.

        While ``memory / self.memory_limit`` exceeds
        ``self.memory_pause_fraction`` a throttled garbage collection is
        requested and a running worker is paused; once the fraction is back at
        or below the threshold a paused worker is resumed. Nothing happens
        when ``memory_pause_fraction`` is ``False``.

        Parameters
        ----------
        worker : Worker
            Worker whose ``status`` is switched between ``Status.running``
            and ``Status.paused``.
        memory : int
            Current process memory usage in bytes.
        """
        if self.memory_pause_fraction is False:
            return

        # Pausing is relative to the memory limit, so one must be set whenever
        # pausing is enabled. After this assert the limit is known to be set
        # and can be formatted directly in the log messages below.
        assert self.memory_limit
        frac = memory / self.memory_limit
        # Pause worker threads once usage is above the configured pause fraction
        if frac > self.memory_pause_fraction:
            # Try to free some memory while in paused state
            self._throttled_gc.collect()
            if worker.status == Status.running:
                logger.warning(
                    "Worker is at %d%% memory usage. Pausing worker.  "
                    "Process memory: %s -- Worker memory limit: %s",
                    int(frac * 100),
                    format_bytes(memory),
                    format_bytes(self.memory_limit),
                )
                worker.status = Status.paused
        elif worker.status == Status.paused:
            logger.warning(
                "Worker is at %d%% memory usage. Resuming worker. "
                "Process memory: %s -- Worker memory limit: %s",
                int(frac * 100),
                format_bytes(memory),
                format_bytes(self.memory_limit),
            )
            worker.status = Status.running
示例#2
0
def test_pause_executor(c, s, a):
    """Worker ``a`` pauses (and logs it) once its memory use grows too large.

    Written as a ``yield``-based coroutine test; ``c``/``s``/``a`` are
    presumably client, scheduler and worker supplied by a cluster decorator
    outside this view -- TODO confirm.
    """
    memory = psutil.Process().memory_info().rss
    # Set the limit to twice the current RSS plus 200 MB, so the ~400 MB
    # allocation below is presumably enough to cross the pause threshold.
    a.memory_limit = memory / 0.5 + 200e6
    np = pytest.importorskip("numpy")

    def f():
        # Hold ~400 MB for a second; ``x`` is intentionally unused, it only
        # keeps the allocation alive.
        x = np.ones(int(400e6), dtype="u1")
        sleep(1)

    with captured_logger(logging.getLogger("distributed.worker")) as logger:
        future = c.submit(f)
        futures = c.map(slowinc, range(30), delay=0.1)

        # Wait up to 4 s for the worker to pause; on timeout the assertion
        # message shows current RSS, the limit and the number of stored keys.
        start = time()
        while not a.paused:
            yield gen.sleep(0.01)
            assert time() < start + 4, (
                format_bytes(psutil.Process().memory_info().rss),
                format_bytes(a.memory_limit),
                len(a.data),
            )
        out = logger.getvalue()
        assert "memory" in out.lower()
        assert "pausing" in out.lower()

    # While paused, only a few of the slowinc tasks should have finished.
    assert sum(f.status == "finished" for f in futures) < 4

    yield wait(futures)
示例#3
0
def test_ucx_config_w_env_var(cleanup, loop, monkeypatch):
    """``DASK_RMM__POOL_SIZE`` sizes the RMM pool on a UCX scheduler and worker."""
    size = "1000.00 MB"
    monkeypatch.setenv("DASK_RMM__POOL_SIZE", size)
    dask.config.refresh()

    port = "13339"
    sched_addr = f"ucx://{HOST}:{port}"

    scheduler_cmd = [
        "dask-scheduler", "--no-dashboard", "--protocol", "ucx", "--port", port,
    ]
    worker_cmd = [
        "dask-worker",
        sched_addr,
        "--no-dashboard",
        "--protocol",
        "ucx",
        "--no-nanny",
    ]

    with popen(scheduler_cmd), popen(worker_cmd):
        with Client(sched_addr, loop=loop, timeout=10) as c:
            # Block until the worker has registered with the scheduler.
            while not c.scheduler_info()["workers"]:
                sleep(0.1)

            # The scheduler's pool should report the configured size as free.
            scheduler_free = c.run_on_scheduler(rmm.get_info).free
            assert size == format_bytes(scheduler_free)

            # So should the pool of the (only) worker.
            worker_addr = list(c.scheduler_info()["workers"])[0]
            worker_free = c.run(rmm.get_info)[worker_addr].free
            assert size == format_bytes(worker_free)
示例#4
0
def main(args):
    """Run the merge benchmark on a local CUDA cluster and print a report.

    ``run`` is invoked ``args.runs`` times; only the final invocation writes a
    profile (to ``args.profile``). Each worker's incoming transfer log is then
    aggregated per (receiver, sender) pair into bandwidth quartiles and total
    bytes, ignoring transfers smaller than ``args.ignore_size``.
    """
    cluster = LocalCUDACluster(
        protocol=args.protocol,
        n_workers=args.n_workers,
        CUDA_VISIBLE_DEVICES=args.devs,
    )
    client = Client(cluster)

    # cuDF's default allocator is pooled unless pooling was disabled.
    client.run(cudf.set_allocator, "default", pool=not args.no_pool_allocator)

    # All runs but the last are unprofiled; only the final run is profiled.
    took_list = [run(args, write_profile=None) for _ in range(args.runs - 1)]
    took_list.append(run(args, write_profile=args.profile))

    # Group every sufficiently large transfer by (receiver, sender) address.
    incoming_logs = client.run(
        lambda dask_worker: dask_worker.incoming_transfer_log)
    pair_bandwidths = defaultdict(list)
    pair_nbytes = defaultdict(list)
    for receiver, transfers in incoming_logs.items():
        for transfer in transfers:
            if transfer["total"] >= args.ignore_size:
                pair = (receiver, transfer["who"])
                pair_bandwidths[pair].append(transfer["bandwidth"])
                pair_nbytes[pair].append(transfer["total"])

    # Re-key by worker name and format quartiles / totals for display.
    workers = cluster.scheduler.workers
    bandwidths = {
        (workers[w1].name, workers[w2].name): [
            "%s/s" % format_bytes(q)
            for q in numpy.quantile(samples, [0.25, 0.50, 0.75])
        ]
        for (w1, w2), samples in pair_bandwidths.items()
    }
    total_nbytes = {
        (workers[w1].name, workers[w2].name): format_bytes(sum(sizes))
        for (w1, w2), sizes in pair_nbytes.items()
    }

    print("Merge benchmark")
    print("--------------------------")
    print(f"Chunk-size  | {args.chunk_size}")
    print(f"Frac-match  | {args.frac_match}")
    print(f"Ignore-size | {format_bytes(args.ignore_size)}")
    print(f"Protocol    | {args.protocol}")
    print(f"Device(s)   | {args.devs}")
    print("==========================")
    for took in took_list:
        print(f"Total time  | {format_time(took)}")
    print("==========================")
    print("(w1,w2)     | 25% 50% 75% (total nbytes)")
    print("--------------------------")
    for (d1, d2), (q25, q50, q75) in sorted(bandwidths.items()):
        print("(%02d,%02d)     | %s %s %s (%s)" %
              (d1, d2, q25, q50, q75, total_nbytes[(d1, d2)]))
示例#5
0
    def update(self):
        """Rebuild the per-GPU memory and utilization plot data.

        Every GPU reported by every scheduler worker becomes one row: ``y``
        numbers rows consecutively across workers, ``gpu-index`` is the GPU's
        position within its worker. Also sets the memory figure's title
        (used / total GPU memory) and x-range end (largest single GPU total).
        """
        with log_errors():
            worker_states = list(self.scheduler.workers.values())

            utilization, memory, gpu_index, y, addresses = [], [], [], [], []
            memory_total = 0
            memory_max = 0

            for ws in worker_states:
                info = ws.extra["gpu"]
                metrics = ws.metrics["gpu"]
                per_gpu = zip(
                    metrics["utilization"],
                    metrics["memory-used"],
                    info["memory-total"],
                )
                for local_index, (util, used, total) in enumerate(per_gpu):
                    memory_max = max(memory_max, total)
                    memory_total += total
                    utilization.append(int(util))
                    memory.append(used)
                    addresses.append(ws.address)
                    gpu_index.append(local_index)
                    # Consecutive row number across all GPUs of all workers.
                    y.append(len(y))

            result = {
                "memory": memory,
                "memory-half": [m / 2 for m in memory],
                "memory_text": [format_bytes(m) for m in memory],
                "utilization": utilization,
                "utilization-half": [u / 2 for u in utilization],
                "worker": addresses,
                "gpu-index": gpu_index,
                "y": y,
                "escaped_worker": [escape.url_escape(w) for w in addresses],
            }

            self.memory_figure.title.text = "GPU Memory: %s / %s" % (
                format_bytes(sum(memory)),
                format_bytes(memory_total),
            )
            self.memory_figure.x_range.end = memory_max

            update(self.source, result)
示例#6
0
    def update(self):
        """Rebuild the per-worker GPU memory and utilization plot data.

        Workers whose ``extra`` has no ``"gpu"`` entry are skipped; every other
        worker contributes one row. Also sets the memory figure's title
        (used / total GPU memory) and x-range end (largest single GPU total).
        """
        with log_errors():
            utilization, memory, gpu_index, y, addresses = [], [], [], [], []
            memory_total = 0
            memory_max = 0

            for position, ws in enumerate(list(self.scheduler.workers.values())):
                if "gpu" not in ws.extra:
                    continue
                info = ws.extra["gpu"]
                metrics = ws.metrics["gpu"]
                util = metrics["utilization"]
                used = metrics["memory-used"]
                total = info["memory-total"]

                memory_max = max(memory_max, total)
                memory_total += total
                utilization.append(int(util))
                memory.append(used)
                addresses.append(ws.address)
                # NOTE(review): both the GPU index and the row position are the
                # worker's index among *all* workers, so skipped workers leave
                # gaps -- confirm this is intended.
                gpu_index.append(position)
                y.append(position)

            result = {
                "memory": memory,
                "memory-half": [m / 2 for m in memory],
                "memory_text": [format_bytes(m) for m in memory],
                "utilization": utilization,
                "utilization-half": [u / 2 for u in utilization],
                "worker": addresses,
                "gpu-index": gpu_index,
                "y": y,
                "escaped_worker": [escape.url_escape(w) for w in addresses],
            }

            self.memory_figure.title.text = "GPU Memory: {} / {}".format(
                format_bytes(sum(memory)),
                format_bytes(memory_total),
            )
            self.memory_figure.x_range.end = memory_max

            update(self.source, result)
示例#7
0
    def update(self):
        """Push transfers logged since the previous call into the plot sources.

        New entries of the worker's outgoing and incoming transfer logs
        (tracked via ``outgoing_count`` / ``incoming_count``) are copied and
        annotated with ``alpha`` (bandwidth-based opacity), ``y`` (one row per
        peer) and ``hover`` text. Their ``middle``/``duration``/``start``/
        ``stop`` fields are scaled by 1000, and the entries are then written
        into ``self.outgoing`` / ``self.incoming``.
        """
        with log_errors():
            log = self.worker.outgoing_transfer_log
            # Clamp to the log length: if the log is bounded (TODO confirm)
            # older entries may already be gone, and indexing past its start
            # would raise before ``last_outgoing`` is updated, i.e. forever.
            n = min(self.worker.outgoing_count - self.last_outgoing, len(log))
            # Copies, newest first, so the edits below leave the log intact.
            outgoing = [log[-i].copy() for i in range(1, n + 1)]
            self.last_outgoing = self.worker.outgoing_count

            log = self.worker.incoming_transfer_log
            n = min(self.worker.incoming_count - self.last_incoming, len(log))
            incoming = [log[-i].copy() for i in range(1, n + 1)]
            self.last_incoming = self.worker.incoming_count

            for msgs, source in (
                (incoming, self.incoming),
                (outgoing, self.outgoing),
            ):
                for msg in msgs:
                    msg.pop("compressed", None)
                    del msg["keys"]

                    # A zero duration falls back to 0.5 so that neither the
                    # opacity nor the hover text divides by zero.
                    bandwidth = msg["total"] / (msg["duration"] or 0.5)
                    msg["alpha"] = max(min(bandwidth / 500e6, 1), 0.3)
                    # Each peer gets a stable row, in order of first appearance.
                    msg["y"] = self.who.setdefault(msg["who"], len(self.who))

                    msg["hover"] = "%s / %s = %s/s" % (
                        format_bytes(msg["total"]),
                        format_time(msg["duration"]),
                        format_bytes(bandwidth),
                    )

                    # Presumably seconds -> milliseconds for the plot axes.
                    for k in ["middle", "duration", "start", "stop"]:
                        msg[k] = msg[k] * 1000

                if msgs:
                    msgs = transpose(msgs)
                    # A gap of more than 10000 (scaled units) after the last
                    # plotted transfer replaces the data instead of appending.
                    if (
                        len(source.data["stop"])
                        and min(msgs["start"]) > source.data["stop"][-1] + 10000
                    ):
                        source.data.update(msgs)
                    else:
                        source.stream(msgs, rollover=10000)
示例#8
0
def test_filters():
    """``format_bytes`` is a registered template filter, and custom filters work."""
    bytes_template = get_template("bytes.html.j2")
    assert format_bytes in FILTERS.values()
    rendered = bytes_template.render(foo=2e9)
    assert format_bytes(2e9) in rendered

    custom_template = get_template("custom_filter.html.j2")
    assert "baz" in custom_template.render(foo=None)
示例#9
0
def test_job_script(tmpdir):
    """SGE job scripts carry the requested resources, options and log paths."""
    log_directory = tmpdir.strpath
    with SGECluster(
        cores=6,
        processes=2,
        memory="12GB",
        queue="my-queue",
        project="my-project",
        walltime="02:00:00",
        env_extra=["export MY_VAR=my_var"],
        job_extra=["-w e", "-m e"],
        log_directory=log_directory,
        resource_spec="h_vmem=12G,mem_req=12G",
    ) as cluster:
        job_script = cluster.job_script()
        # 12GB / 2 processes -> 6GB per worker; 6 cores / 2 -> 3 threads.
        memory_limit = format_bytes(parse_bytes("6GB")).replace(" ", "")

        expected_fragments = (
            "--nprocs 2",
            "--nthreads 3",
            f"--memory-limit {memory_limit}",
            "-q my-queue",
            "-P my-project",
            "-l h_rt=02:00:00",
            "export MY_VAR=my_var",
            "#$ -w e",
            "#$ -m e",
            f"#$ -e {log_directory}",
            f"#$ -o {log_directory}",
            "-l h_vmem=12G,mem_req=12G",
            "#$ -cwd",
            "#$ -j y",
        )
        for fragment in expected_fragments:
            assert fragment in job_script
示例#10
0
    def _widget_status(self):
        workers = len(self.scheduler.workers)
        cores = sum(ws.nthreads for ws in self.scheduler.workers.values())
        memory = sum(ws.memory_limit for ws in self.scheduler.workers.values())
        memory = format_bytes(memory)
        text = """
<div>
  <style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
  </style>
  <table style="text-align: right;">
    <tr><th>Workers</th> <td>%d</td></tr>
    <tr><th>Cores</th> <td>%d</td></tr>
    <tr><th>Memory</th> <td>%s</td></tr>
  </table>
</div>
""" % (
            workers,
            cores,
            memory,
        )
        return text
示例#11
0
 def oom(nbytes: int) -> bool:
     """Try to handle an out-of-memory error by spilling.

     Asks ``self.manager`` to evict ``nbytes`` (serializing with the "dask"
     and "pickle" serializers), then runs ``gc.collect()``.

     Returns ``True`` when some memory was freed, so the allocation should be
     retried. Otherwise logs a warning with the manager state and the current
     stack, and returns ``False``.
     """
     freed = self.manager.evict(
         nbytes=nbytes,
         proxies_access=self.manager._dev.buffer_info,
         serializer=lambda p: p._pxy_serialize(serializers=("dask", "pickle")),
     )
     gc.collect()
     if freed > 0:
         # Ask RMM to retry the allocation
         return True

     # Nothing was spilled: capture where the failing allocation came from.
     with io.StringIO() as buf:
         traceback.print_stack(file=buf)
         stack_text = buf.getvalue()
     self.logger.warning(
         "RMM allocation of %s failed, spill-on-demand couldn't "
         "find any device memory to spill:\n%s\ntraceback:\n%s\n",
         format_bytes(nbytes),
         self.manager.pprint(),
         stack_text,
     )
     # Since we didn't find anything to spill, we give up.
     return False
示例#12
0
def test_job_script():
    """OAR job scripts carry the expected directives, env lines and command."""
    formatted_bytes = format_bytes(parse_bytes("7GB")).replace(" ", "")
    worker_cmd = f"{sys.executable} -m distributed.cli.dask_worker tcp://"
    worker_opts = f"--nthreads 2 --nprocs 4 --memory-limit {formatted_bytes}"

    def check_common(job_script):
        # Directives and worker options shared by both configurations below.
        assert "#OAR" in job_script
        assert "#OAR -n dask-worker" in job_script
        assert f"--memory-limit {formatted_bytes}" in job_script
        assert "#OAR -l /nodes=1/core=8,walltime=00:02:00" in job_script
        assert "#OAR --project" not in job_script
        assert "#OAR -q" not in job_script
        assert worker_cmd in job_script
        assert worker_opts in job_script

    with OARCluster(walltime="00:02:00", processes=4, cores=8,
                    memory="28GB") as cluster:
        job_script = cluster.job_script()
        check_common(job_script)
        assert "export " not in job_script

    env_lines = [
        'export LANG="en_US.utf8"',
        'export LANGUAGE="en_US.utf8"',
        'export LC_ALL="en_US.utf8"',
    ]
    with OARCluster(
        walltime="00:02:00",
        processes=4,
        cores=8,
        memory="28GB",
        env_extra=env_lines,
    ) as cluster:
        job_script = cluster.job_script()
        check_common(job_script)
        for line in env_lines:
            assert line in job_script
示例#13
0
def cache_all_catalog_items(cat, cache_storage=None):
    """Cache all catalog items from catalog `cat` that contain `cache::` in url to
    folder optionally specified by `cache_storage`.

    fsspec's ``simplecache`` is configured with ``same_names=True`` (and
    ``cache_storage`` when given). Items whose name contains ``CRU_TS`` and
    nested YAML catalogs are skipped. Each item is loaded (``read()`` for the
    geopandas/regionmask sources, ``to_dask()`` otherwise). When
    ``cache_storage`` is given, the function then checks that the cached file
    exists there. Failures are reported with ``print`` and do not stop the
    loop.

    Example:
        >>> cat = intake.open_catalog('master.yaml')
        >>> cache_all_catalog_items(cat, 'test_cache_folder')
        >>> os.path.exists('test_cache_folder/HadCRUT.4.6.0.0.median.nc')
    """
    # ``same_names`` is expected to keep the original file name, which the
    # existence check below relies on.
    if cache_storage:
        fsspec.config.conf["simplecache"] = {
            "cache_storage": cache_storage,
            "same_names": True,
        }
        print("fsspec.config.conf", fsspec.config.conf, "\n")
    else:
        fsspec.config.conf["simplecache"] = {"same_names": True}

    for item_str in cat.walk(depth=2):
        # NOTE(review): CRU_TS items are excluded; the reason is not stated here.
        if "CRU_TS" in item_str:
            continue
        item = getattr(cat, item_str)
        if isinstance(cat[item_str], intake.catalog.local.YAMLFileCatalog):
            continue
        if "cache::" not in item.urlpath:
            continue

        filename = item.urlpath.split("/")[-1]
        print(f"try to cache {item_str} from {item.urlpath} to "
              f"{cache_storage}/{filename}")
        try:
            if isinstance(
                    item,
                (
                    intake_geopandas.RegionmaskSource,
                    intake_geopandas.GeoPandasFileSource,
                ),
            ):
                ds = item.read()
                print(
                    item_str,
                    "\n",
                    type(item),
                    "\n",
                    ds.dims,
                    "\n",
                    ds.coords,
                    "\nsize = ",
                    format_bytes(ds.nbytes),
                )
            else:
                ds = item.to_dask()
                print(item_str, "\n", type(item))
            if cache_storage:
                cached_files = os.listdir(cache_storage)
                if filename not in cached_files:
                    # Explicit raise instead of ``assert``: an assert is
                    # stripped under ``python -O``, and this message is
                    # reported by the handler below.
                    raise FileNotFoundError(
                        f"{item_str} caching failed: {cached_files}")
            print("successful", item_str, "\n")
        except Exception as e:
            print(
                f"{item_str} failed, error {type(e).__name__}: {e}\n"
            )
示例#14
0
    def _widget_status(self):
        """Render the status widget from ``self.scheduler_info``.

        Returns ``None`` when ``scheduler_info`` has no ``"workers"`` entry.
        Otherwise returns ``_widget_status_template`` filled with the worker
        count, total threads and total memory limit (via ``format_bytes``).
        """
        try:
            workers = self.scheduler_info["workers"]
        except KeyError:
            return None

        infos = workers.values()
        n_cores = sum(info["nthreads"] for info in infos)
        total_memory = sum(info["memory_limit"] for info in infos)
        return _widget_status_template % (
            len(workers),
            n_cores,
            format_bytes(total_memory),
        )
示例#15
0
 def _gc_callback(self, phase, info):
     """Time full (generation 2) garbage collections and log notable ones.

     The ``phase``/``info`` arguments match the ``gc.callbacks`` protocol, so
     this is presumably registered there -- confirm. On ``"start"`` the
     collection is timed and RSS recorded. On ``"stop"`` it warns when full
     collections recently used at least ``self._warn_over_frac`` of CPU time,
     logs at info level when RSS dropped by at least
     ``self._info_over_rss_win``, and warns about uncollectable objects.
     """
     # Young generations are small and collected very often,
     # don't waste time measuring them
     if info["generation"] != 2:
         return

     # Without a process handle, RSS counts as 0.
     rss = self._proc.memory_info().rss if self._proc is not None else 0

     if phase == "start":
         self._fractional_timer.start_timing()
         self._gc_rss_before = rss
         return

     assert phase == "stop"
     self._fractional_timer.stop_timing()

     frac = self._fractional_timer.running_fraction
     if frac is not None and frac >= self._warn_over_frac:
         logger.warning(
             "full garbage collections took %d%% CPU time "
             "recently (threshold: %d%%)",
             100 * frac,
             100 * self._warn_over_frac,
         )

     rss_saved = self._gc_rss_before - rss
     if rss_saved >= self._info_over_rss_win:
         logger.info(
             "full garbage collection released %s "
             "from %d reference cycles (threshold: %s)",
             format_bytes(rss_saved),
             info["collected"],
             format_bytes(self._info_over_rss_win),
         )

     uncollectable = info["uncollectable"]
     if uncollectable > 0:
         # This should ideally never happen on Python 3, but who knows?
         logger.warning(
             "garbage collector couldn't collect %d objects, "
             "please look in gc.garbage",
             uncollectable,
         )
示例#16
0
def test_job_script(Cluster):
    """PBS job scripts honour defaults, and an explicit resource spec overrides them."""
    formatted_bytes = format_bytes(parse_bytes("7GB")).replace(" ", "")
    worker_cmd = f"{sys.executable} -m distributed.cli.dask_worker tcp://"
    worker_opts = f"--nthreads 2 --nprocs 4 --memory-limit {formatted_bytes}"

    with Cluster(walltime="00:02:00", processes=4, cores=8,
                 memory="28GB") as cluster:
        job_script = cluster.job_script()
        for present in (
            "#PBS",
            "#PBS -N dask-worker",
            "#PBS -l select=1:ncpus=8:mem=27GB",
            "#PBS -l walltime=00:02:00",
            worker_cmd,
            worker_opts,
        ):
            assert present in job_script
        for absent in ("#PBS -q", "#PBS -A"):
            assert absent not in job_script

    with Cluster(
        queue="regular",
        project="DaskOnPBS",
        processes=4,
        cores=8,
        resource_spec="select=1:ncpus=24:mem=100GB",
        memory="28GB",
    ) as cluster:
        job_script = cluster.job_script()
        for present in (
            "#PBS -q regular",
            "#PBS -N dask-worker",
            "#PBS -l select=1:ncpus=24:mem=100GB",
            "#PBS -l walltime=",
            "#PBS -A DaskOnPBS",
            worker_cmd,
            worker_opts,
        ):
            assert present in job_script
        # The explicit resource_spec replaces the derived select line.
        assert "#PBS -l select=1:ncpus=8:mem=27GB" not in job_script
示例#17
0
    def __repr__(self):
        """Summarise as ``Name('address', workers=N, threads=T[, memory=M])``.

        Total memory is included only when every worker reports a truthy
        ``memory_limit``.
        """
        worker_infos = self.scheduler_info["workers"].values()
        pieces = [
            "%s(%r" % (self._cluster_class_name, self.scheduler_address),
            "workers=%d" % len(worker_infos),
            "threads=%d" % sum(w["nthreads"] for w in worker_infos),
        ]
        limits = [w["memory_limit"] for w in worker_infos]
        if all(limits):
            pieces.append("memory=" + format_bytes(sum(limits)))
        return ", ".join(pieces) + ")"
示例#18
0
    def __repr__(self):
        """Summarise as ``Name('address', workers=N, threads=T[, memory=M])``.

        ``Name`` is ``self._name`` when present, else the class name. Total
        memory is included only when every worker reports a truthy
        ``memory_limit``.
        NOTE(review): the worker count comes from ``self.workers`` while
        threads and memory come from ``scheduler_info`` -- confirm intended.
        """
        name = getattr(self, "_name", type(self).__name__)
        worker_infos = self.scheduler_info["workers"].values()
        n_threads = sum(w["nthreads"] for w in worker_infos)
        text = "%s(%r, workers=%d, threads=%d" % (
            name,
            self.scheduler_address,
            len(self.workers),
            n_threads,
        )

        limits = [w["memory_limit"] for w in worker_infos]
        if all(limits):
            text += ", memory=" + format_bytes(sum(limits))
        return text + ")"
示例#19
0
    def _widget_status(self):
        workers = len(self.scheduler_info["workers"])
        if hasattr(self, "worker_spec"):
            requested = sum(1 if "group" not in each else len(each["group"])
                            for each in self.worker_spec.values())
        elif hasattr(self, "workers"):
            requested = len(self.workers)
        else:
            requested = workers
        cores = sum(v["nthreads"]
                    for v in self.scheduler_info["workers"].values())
        memory = sum(v["memory_limit"]
                     for v in self.scheduler_info["workers"].values())
        memory = format_bytes(memory)
        text = """
<div>
  <style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
  </style>
  <table style="text-align: right;">
    <tr> <th>Workers</th> <td>%s</td></tr>
    <tr> <th>Cores</th> <td>%d</td></tr>
    <tr> <th>Memory</th> <td>%s</td></tr>
  </table>
</div>
""" % (
            workers if workers == requested else "%d / %d" %
            (workers, requested),
            cores,
            memory,
        )
        return text
示例#20
0
def test_job_script():
    """HTCondor job scripts expose resources as ClassAd attributes and honour
    env, job and submit/cancel command extras."""
    with HTCondorCluster(
        cores=4,
        processes=2,
        memory="100MB",
        disk="100MB",
        env_extra=['export LANG="en_US.utf8"', 'export LC_ALL="en_US.utf8"'],
        job_extra={"+Extra": "True"},
        submit_command_extra=["-verbose"],
        cancel_command_extra=["-forcex"],
    ) as cluster:
        job_script = cluster.job_script()
        # 100MB split across 2 processes -> 50MB memory limit per worker.
        memory_limit = format_bytes(parse_bytes("50MB")).replace(" ", "")

        expected_fragments = [
            "RequestCpus = MY.DaskWorkerCores",
            "RequestDisk = floor(MY.DaskWorkerDisk / 1024)",
            "RequestMemory = floor(MY.DaskWorkerMemory / 1048576)",
            "MY.DaskWorkerCores = 4",
            "MY.DaskWorkerDisk = 100000000",
            "MY.DaskWorkerMemory = 100000000",
            'MY.JobId = "$(ClusterId).$(ProcId)"',
            # env_extra lines appear without their ``export`` keyword
            "LANG=en_US.utf8",
            "LC_ALL=en_US.utf8",
            "+Extra = True",
            f"{sys.executable} -m distributed.cli.dask_worker tcp://",
            f"--memory-limit {memory_limit}",
            "--nthreads 2",
            "--nprocs 2",
        ]
        for fragment in expected_fragments:
            assert fragment in job_script
        assert "export" not in job_script

        assert re.search(r"condor_submit\s.*-verbose",
                         cluster._dummy_job.submit_command)
        assert re.search(r"condor_rm\s.*-forcex",
                         cluster._dummy_job.cancel_command)
示例#21
0
def test_format_bytes(n, expect):
    """``format_bytes`` renders ``int(n)`` bytes as ``expect``.

    ``n``/``expect`` are presumably supplied by a parametrize decorator
    outside this view.
    """
    rendered = format_bytes(int(n))
    assert rendered == expect
async def run(args):
    """Benchmark ``(x + x.T).sum()`` on a random CuPy-backed dask array.

    Starts an asynchronous ``LocalCUDACluster`` with one worker per entry of
    the comma-separated ``args.devs``, then configures RMM/CuPy allocation on
    the workers and on the scheduler. It times one compute of the expression,
    under ``performance_report`` when ``args.profile`` is set. Finally it
    prints the timing, plus per-worker-pair bandwidth quartiles and byte
    totals from the workers' incoming transfer logs; transfers smaller than
    ``args.ignore_size`` are ignored.
    """

    # Set up workers on the local machine
    async with LocalCUDACluster(
            protocol=args.protocol,
            n_workers=len(args.devs.split(",")),
            CUDA_VISIBLE_DEVICES=args.devs,
            ucx_net_devices="auto",
            enable_infiniband=True,
            enable_nvlink=True,
            asynchronous=True,
    ) as cluster:
        async with Client(cluster, asynchronous=True) as client:

            def _worker_setup(size=None):
                # Executed remotely (via client.run / run_on_scheduler).
                # Reinitialises RMM -- pooled unless ``args.no_rmm_pool`` --
                # with ``size`` as the initial pool size, and routes CuPy
                # allocations through RMM.
                # NOTE(review): ``devices=0`` presumably means the first device
                # visible to each process -- confirm.
                import rmm

                rmm.reinitialize(
                    pool_allocator=not args.no_rmm_pool,
                    devices=0,
                    initial_pool_size=size,
                )
                cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)

            await client.run(_worker_setup)
            # Create an RMM pool on the scheduler due to occasional deserialization
            # of CUDA objects. May cause issues with InfiniBand otherwise.
            # NOTE(review): 1e9 is a float; confirm rmm.reinitialize accepts a
            # non-int initial_pool_size.
            await client.run_on_scheduler(_worker_setup, 1e9)

            # Create a simple random array
            # (persisted and awaited so its creation is excluded from timing)
            rs = da.random.RandomState(RandomState=cupy.random.RandomState)
            x = rs.random((args.size, args.size),
                          chunks=args.chunk_size).persist()
            await wait(x)

            # Execute the operations to benchmark
            # (``clock`` is defined outside this view; presumably a timer such
            # as ``time.perf_counter`` -- confirm)
            if args.profile is not None:
                async with performance_report(filename=args.profile):
                    t1 = clock()
                    await client.compute((x + x.T).sum())
                    took = clock() - t1
            else:
                t1 = clock()
                await client.compute((x + x.T).sum())
                took = clock() - t1

            # Collect, aggregate, and print peer-to-peer bandwidths
            # Keys are (receiving worker address, sending worker address),
            # re-keyed below by worker name.
            incoming_logs = await client.run(
                lambda dask_worker: dask_worker.incoming_transfer_log)
            bandwidths = defaultdict(list)
            total_nbytes = defaultdict(list)
            for k, L in incoming_logs.items():
                for d in L:
                    if d["total"] >= args.ignore_size:
                        bandwidths[k, d["who"]].append(d["bandwidth"])
                        total_nbytes[k, d["who"]].append(d["total"])
            bandwidths = {(
                cluster.scheduler.workers[w1].name,
                cluster.scheduler.workers[w2].name,
            ): [
                "%s/s" % format_bytes(x)
                for x in np.quantile(v, [0.25, 0.50, 0.75])
            ]
                          for (w1, w2), v in bandwidths.items()}
            total_nbytes = {(
                cluster.scheduler.workers[w1].name,
                cluster.scheduler.workers[w2].name,
            ): format_bytes(sum(nb))
                            for (w1, w2), nb in total_nbytes.items()}

            print("Roundtrip benchmark")
            print("--------------------------")
            print(f"Size        | {args.size}*{args.size}")
            print(f"Chunk-size  | {args.chunk_size}")
            print(f"Ignore-size | {format_bytes(args.ignore_size)}")
            print(f"Protocol    | {args.protocol}")
            print(f"Device(s)   | {args.devs}")
            print(f"npartitions | {x.npartitions}")
            print("==========================")
            print(f"Total time  | {format_time(took)}")
            print("==========================")
            print("(w1,w2)     | 25% 50% 75% (total nbytes)")
            print("--------------------------")
            for (d1, d2), bw in sorted(bandwidths.items()):
                print("(%02d,%02d)     | %s %s %s (%s)" %
                      (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))
示例#23
0
# # The workers
# Notebook exported as a script: ``# #`` lines look like section headers and
# the ``%time`` lines are IPython magics, so this only runs under IPython.
nb_workers=4
from dask.distributed import Client, LocalCluster
cluster = LocalCluster(n_workers=nb_workers)
c = Client(cluster)

# Bare expression: in a notebook this displays the client summary.
c

# NOTE(review): ``ensure_dict`` is imported but unused in this script.
from dask.utils import ensure_dict, format_bytes
    
# Summarise the cluster: worker count, total cores and, when every worker
# reports a memory limit, the total memory.
wk = c.scheduler_info()["workers"]

text="Workers= " + str(len(wk))
memory = [w["memory_limit"] for w in wk.values()]
cores = sum(w["nthreads"] for w in wk.values())
text += ", Cores=" + str(cores)
if all(memory):
    text += ", Memory=" + format_bytes(sum(memory))
print(text)


# # The data

# NOTE(review): ``xr`` (presumably ``import xarray as xr``) is not imported
# anywhere in this script -- confirm.
%time ds=xr.open_zarr('/mnt/alberta/equipes/IGE/meom/workdir/albert/eNATL60/zarr/eNATL60-BLBT02-SSH-1h')
# recorded run time: ~3.6 s
mean=ds.sossheig.mean(dim='time_counter')
%time mean.load()
# recorded run time: ~1 h 51 min
cluster.close()
示例#24
0
 def worker_process_memory(self):
     """Return each worker process's share of ``worker_memory`` as text.

     ``worker_memory`` is divided evenly by ``worker_processes``, rendered with
     ``format_bytes``, and stripped of spaces (presumably so the value forms a
     single command-line token -- confirm).
     """
     per_process = self.worker_memory / self.worker_processes
     return format_bytes(per_process).replace(" ", "")
示例#25
0
def main(args):
    """Run the shuffle benchmark and report its results.

    The benchmark is executed ``args.runs`` times (at least once); only the
    last run receives ``args.profile``.  Afterwards a per-run wall-clock /
    throughput table is printed and, for the ``dask`` backend, the 25/50/75%
    worker-to-worker transfer-rate quantiles.  When ``args.benchmark_json``
    is set, one JSON line per run is appended to that file.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options; every attribute read below must exist.
    """
    cluster_options = get_cluster_options(args)
    Cluster = cluster_options["class"]
    cluster_args = cluster_options["args"]
    cluster_kwargs = cluster_options["kwargs"]
    scheduler_addr = cluster_options["scheduler_addr"]

    if args.sched_addr:
        # Connect to an already-running scheduler instead of starting a cluster.
        client = Client(args.sched_addr)
    else:
        filterwarnings("ignore",
                       message=".*NVLink.*rmm_pool_size.*",
                       category=UserWarning)

        cluster = Cluster(*cluster_args, **cluster_kwargs)
        if args.multi_node:
            import time

            # Allow some time for workers to start and connect to scheduler
            # TODO: make this a command-line argument?
            time.sleep(15)

        client = Client(scheduler_addr if args.multi_node else cluster)

    if args.type == "gpu":
        client.run(
            setup_memory_pool,
            pool_size=args.rmm_pool_size,
            disable_pool=args.disable_rmm_pool,
            log_directory=args.rmm_log_directory,
        )
        # Create an RMM pool on the scheduler due to occasional deserialization
        # of CUDA objects. May cause issues with InfiniBand otherwise.
        client.run_on_scheduler(
            setup_memory_pool,
            pool_size=1e9,
            disable_pool=args.disable_rmm_pool,
            log_directory=args.rmm_log_directory,
        )

    scheduler_workers = client.run_on_scheduler(get_scheduler_workers)
    n_workers = len(scheduler_workers)
    client.wait_for_workers(n_workers)

    if args.all_to_all:
        all_to_all(client)

    # All runs but the last are unprofiled; at least one run always happens.
    took_list = []
    for _ in range(args.runs - 1):
        took_list.append(run(client, args, n_workers, write_profile=None))
    took_list.append(
        run(client, args, n_workers,
            write_profile=args.profile))  # Only profiling the last run

    # Collect, aggregate, and print peer-to-peer bandwidths
    incoming_logs = client.run(
        lambda dask_worker: dask_worker.incoming_transfer_log)
    bandwidths = defaultdict(list)
    total_nbytes = defaultdict(list)
    for k, L in incoming_logs.items():
        for d in L:
            # Transfers smaller than args.ignore_size are left out of the stats.
            if d["total"] >= args.ignore_size:
                bandwidths[k, d["who"]].append(d["bandwidth"])
                total_nbytes[k, d["who"]].append(d["total"])
    # Re-key each (receiver, sender) address pair by worker name and reduce
    # the samples to formatted quantile / total strings.
    bandwidths = {
        (scheduler_workers[w1].name, scheduler_workers[w2].name): [
            "%s/s" % format_bytes(x)
            for x in numpy.quantile(v, [0.25, 0.50, 0.75])
        ]
        for (w1, w2), v in bandwidths.items()
    }
    total_nbytes = {
        (scheduler_workers[w1].name, scheduler_workers[w2].name):
        format_bytes(sum(nb))
        for (w1, w2), nb in total_nbytes.items()
    }

    t_runs = numpy.empty(len(took_list))
    if args.markdown:
        print("```")
    print("Shuffle benchmark")
    print("-------------------------------")
    print(f"backend        | {args.backend}")
    print(f"partition-size | {format_bytes(args.partition_size)}")
    print(f"in-parts       | {args.in_parts}")
    print(f"protocol       | {args.protocol}")
    print(f"device(s)      | {args.devs}")
    if args.device_memory_limit:
        print(f"memory-limit   | {format_bytes(args.device_memory_limit)}")
    print(f"rmm-pool       | {(not args.disable_rmm_pool)}")
    if args.protocol == "ucx":
        print(f"tcp            | {args.enable_tcp_over_ucx}")
        print(f"ib             | {args.enable_infiniband}")
        print(f"nvlink         | {args.enable_nvlink}")
    print(f"data-processed | {format_bytes(took_list[0][0])}")
    print("===============================")
    print("Wall-clock     | Throughput")
    print("-------------------------------")
    for idx, (data_processed, took) in enumerate(took_list):
        throughput = int(data_processed / took)
        m = format_time(took)
        m += " " * (15 - len(m))
        print(f"{m}| {format_bytes(throughput)}/s")
        # NOTE(review): this keeps only the numeric part of the formatted
        # string and drops its unit, so runs that format_bytes renders in
        # different units are not comparable in the plot.
        t_runs[idx] = float(format_bytes(throughput).split(" ")[0])
    print("===============================")
    if args.markdown:
        print("\n```")

    if args.plot is not None:
        plot_benchmark(t_runs, args.plot, historical=True)

    if args.backend == "dask":
        if args.markdown:
            print(
                "<details>\n<summary>Worker-Worker Transfer Rates</summary>\n\n```"
            )
        print("(w1,w2)        | 25% 50% 75% (total nbytes)")
        print("-------------------------------")
        for (d1, d2), bw in sorted(bandwidths.items()):
            # With multi-node / external schedulers the worker names are
            # presumably not integers, so %02d cannot be used.
            fmt = ("(%s,%s)        | %s %s %s (%s)" if args.multi_node or
                   args.sched_addr else "(%02d,%02d)        | %s %s %s (%s)")
            print(fmt % (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))
        if args.markdown:
            print("```\n</details>\n")

    if args.benchmark_json:
        # One flat key per (worker pair, statistic).  The multi-node key must
        # be an f-string; as a plain literal every pair/statistic collapsed
        # onto the same key and all but the last value were lost.
        # rstrip("/s") turns the "<size>/s" strings built above back into
        # plain sizes that parse_bytes accepts.
        bandwidths_json = {
            (f"bandwidth_({d1},{d2})_{i}"
             if args.multi_node or args.sched_addr
             else "(%02d,%02d)_%s" % (d1, d2, i)): parse_bytes(v.rstrip("/s"))
            for (d1, d2), bw in sorted(bandwidths.items())
            for i, v in zip(
                ["25%", "50%", "75%", "total_nbytes"],
                [bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]],
            )
        }

        # Append one JSON object per run (JSON-lines style).
        with open(args.benchmark_json, "a") as fp:
            for data_processed, took in took_list:
                fp.write(
                    dumps(
                        dict(
                            {
                                "backend": args.backend,
                                "partition_size": args.partition_size,
                                "in_parts": args.in_parts,
                                "protocol": args.protocol,
                                "devs": args.devs,
                                "device_memory_limit":
                                args.device_memory_limit,
                                "rmm_pool": not args.disable_rmm_pool,
                                "tcp": args.enable_tcp_over_ucx,
                                "ib": args.enable_infiniband,
                                "nvlink": args.enable_nvlink,
                                "data_processed": data_processed,
                                "wall_clock": took,
                                "throughput": data_processed / took,
                            },
                            **bandwidths_json,
                        )) + "\n")

    if args.multi_node:
        client.shutdown()
        client.close()
示例#26
0
    async def _maybe_spill(self, worker: Worker, memory: int) -> None:
        """Spill data to disk while process memory is above the spill threshold.

        Returns immediately unless all of these hold: spilling is enabled
        (``memory_spill_fraction`` is not False), ``self.data`` exposes
        ``fast`` and ``evict()``, and ``memory / memory_limit`` exceeds
        ``memory_spill_fraction``.  Otherwise it calls ``data.evict()``
        repeatedly until process memory drops to the target
        (``memory_limit * (memory_target_fraction or memory_spill_fraction)``),
        the fast buffer is empty, or an eviction fails.  After each eviction it
        yields to the event loop and re-reads process memory from
        ``worker.monitor``.  Finally it re-evaluates pausing/resuming via
        ``_maybe_pause_or_unpause`` with the last measured memory.

        Parameters
        ----------
        worker : Worker
            Worker whose ``monitor`` is polled for process memory and whose
            status may be changed by ``_maybe_pause_or_unpause``.
        memory : int
            Process memory (bytes) as measured by the caller.
        """
        if self.memory_spill_fraction is False:
            return

        # SpillBuffer or a duck-type compatible MutableMapping which offers the
        # fast property and evict() methods. Dask-CUDA uses this.
        if not hasattr(self.data, "fast") or not hasattr(self.data, "evict"):
            return
        data = cast(ManualEvictProto, self.data)

        assert self.memory_limit
        frac = memory / self.memory_limit
        if frac <= self.memory_spill_fraction:
            return

        total_spilled = 0
        logger.debug(
            "Worker is at %.0f%% memory usage. Start spilling data to disk.",
            frac * 100,
        )
        # Implement a hysteresis cycle where spilling starts at the spill
        # threshold and stops at the target threshold. Note that here the target
        # threshold is compared against process memory, whereas normally it
        # applies to reported managed memory (e.g. output of sizeof()). If
        # target=False, fall back to the spill threshold, i.e. no hysteresis.
        target = self.memory_limit * (
            self.memory_target_fraction or self.memory_spill_fraction
        )
        count = 0
        # Bytes that must be freed, relative to the initial measurement; used
        # below to decide when a GC is worth forcing.
        need = memory - target
        while memory > target:
            if not data.fast:
                # Nothing left in memory to spill, yet process memory is still
                # above target: the excess is not managed by self.data.
                logger.warning(
                    "Unmanaged memory use is high. This may indicate a memory leak "
                    "or the memory may not be released to the OS; see "
                    "https://distributed.dask.org/en/latest/worker.html#memtrim "
                    "for more information. "
                    "-- Unmanaged memory: %s -- Worker memory limit: %s",
                    format_bytes(memory),
                    format_bytes(self.memory_limit),
                )
                break

            weight = data.evict()
            if weight == -1:
                # Failed to evict:
                # disk full, spill size limit exceeded, or pickle error
                break

            total_spilled += weight
            count += 1
            # Yield control to the event loop between evictions.
            await asyncio.sleep(0)

            memory = worker.monitor.get_process_memory()
            if total_spilled > need and memory > target:
                # Issue a GC to ensure that the evicted data is actually
                # freed from memory and taken into account by the monitor
                # before trying to evict even more data.
                self._throttled_gc.collect()
                memory = worker.monitor.get_process_memory()

        self._maybe_pause_or_unpause(worker, memory)
        if count:
            logger.debug(
                "Moved %d tasks worth %s to disk",
                count,
                format_bytes(total_spilled),
            )
示例#27
0
def main(args):
    """Benchmark a merge on a local ``LocalCUDACluster`` and print a summary.

    The benchmark runs ``args.runs`` times (at least once); only the final
    run is given ``args.profile``.  For the ``dask`` backend, per worker-pair
    transfer-rate quantiles are printed after the throughput table.
    """
    # Set up workers on the local machine
    common = dict(
        protocol=args.protocol,
        n_workers=args.n_workers,
        CUDA_VISIBLE_DEVICES=args.devs,
    )
    if args.protocol == "tcp":
        cluster = LocalCUDACluster(**common)
    else:
        ucx_options = dict(
            enable_tcp_over_ucx=args.enable_tcp_over_ucx,
            enable_infiniband=args.enable_infiniband,
            enable_nvlink=args.enable_nvlink,
        )
        cluster = LocalCUDACluster(
            ucx_net_devices="auto", **common, **ucx_options
        )
        # initialize() receives the same UCX transport flags as the cluster.
        initialize(create_cuda_context=True, **ucx_options)
    client = Client(cluster)

    def _worker_setup(initial_pool_size=None):
        import rmm

        rmm.reinitialize(
            pool_allocator=not args.no_rmm_pool,
            devices=0,
            initial_pool_size=initial_pool_size,
        )
        cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)

    client.run(_worker_setup)
    # The scheduler occasionally deserializes CUDA objects, so it gets an RMM
    # pool as well; May cause issues with InfiniBand otherwise.
    client.run_on_scheduler(_worker_setup, 1e9)

    # Every run except the last is unprofiled.
    took_list = [
        run(client, args, write_profile=None) for _ in range(args.runs - 1)
    ]
    took_list.append(run(client, args, write_profile=args.profile))

    # Gather every worker's incoming-transfer log and group samples by
    # (receiving worker, sending worker), skipping transfers below ignore_size.
    transfer_logs = client.run(
        lambda dask_worker: dask_worker.incoming_transfer_log
    )
    raw_rates = defaultdict(list)
    raw_sizes = defaultdict(list)
    for receiver, entries in transfer_logs.items():
        for entry in entries:
            if entry["total"] >= args.ignore_size:
                pair = (receiver, entry["who"])
                raw_rates[pair].append(entry["bandwidth"])
                raw_sizes[pair].append(entry["total"])

    def worker_name(address):
        return cluster.scheduler.workers[address].name

    bandwidths = {}
    for (receiver, sender), rates in raw_rates.items():
        pair_name = (worker_name(receiver), worker_name(sender))
        bandwidths[pair_name] = [
            "%s/s" % format_bytes(q)
            for q in numpy.quantile(rates, [0.25, 0.50, 0.75])
        ]
    total_nbytes = {}
    for (receiver, sender), sizes in raw_sizes.items():
        pair_name = (worker_name(receiver), worker_name(sender))
        total_nbytes[pair_name] = format_bytes(sum(sizes))

    if args.markdown:
        print("```")
    print("Merge benchmark")
    print("-------------------------------")
    print(f"backend        | {args.backend}")
    print(f"rows-per-chunk | {args.chunk_size}")
    print(f"protocol       | {args.protocol}")
    print(f"device(s)      | {args.devs}")
    print(f"rmm-pool       | {not args.no_rmm_pool}")
    print(f"frac-match     | {args.frac_match}")
    if args.protocol == "ucx":
        print(f"tcp            | {args.enable_tcp_over_ucx}")
        print(f"ib             | {args.enable_infiniband}")
        print(f"nvlink         | {args.enable_nvlink}")
    print(f"data-processed | {format_bytes(took_list[0][0])}")
    print("===============================")
    print("Wall-clock     | Throughput")
    print("-------------------------------")
    for data_processed, took in took_list:
        rate = format_bytes(int(data_processed / took))
        print(f"{format_time(took).ljust(15)}| {rate}/s")
    print("===============================")
    if args.markdown:
        print("\n```")

    if args.backend != "dask":
        return
    if args.markdown:
        print("<details>\n<summary>Worker-Worker Transfer Rates</summary>\n\n```")
    print("(w1,w2)     | 25% 50% 75% (total nbytes)")
    print("-------------------------------")
    for (d1, d2), (q25, q50, q75) in sorted(bandwidths.items()):
        print(
            "(%02d,%02d)     | %s %s %s (%s)"
            % (d1, d2, q25, q50, q75, total_nbytes[(d1, d2)])
        )
    if args.markdown:
        print("```\n</details>\n")
示例#28
0
def main(args):
    """Run the merge benchmark on a (possibly multi-node) cluster and print results.

    Either connects to ``args.sched_addr`` or starts the cluster class chosen
    by ``get_cluster_options``.  The benchmark runs ``args.runs`` times (at
    least once); only the final run is profiled.  Prints a configuration and
    throughput table, optionally plots the runs, and for the ``dask`` backend
    prints worker-to-worker transfer-rate quantiles.
    """
    options = get_cluster_options(args)
    cluster_cls = options["class"]
    cls_args = options["args"]
    cls_kwargs = options["kwargs"]
    scheduler_addr = options["scheduler_addr"]

    if args.sched_addr:
        client = Client(args.sched_addr)
    else:
        filterwarnings(
            "ignore", message=".*NVLink.*rmm_pool_size.*", category=UserWarning
        )
        cluster = cluster_cls(*cls_args, **cls_kwargs)
        if args.multi_node:
            import time

            # Give remote workers time to start and register with the
            # scheduler.  TODO: make this a command-line argument?
            time.sleep(15)
        client = Client(scheduler_addr if args.multi_node else cluster)

    if args.type == "gpu":
        pool_opts = dict(
            disable_pool=args.disable_rmm_pool,
            log_directory=args.rmm_log_directory,
        )
        client.run(setup_memory_pool, pool_size=args.rmm_pool_size, **pool_opts)
        # The scheduler occasionally deserializes CUDA objects, so it gets an
        # RMM pool too; May cause issues with InfiniBand otherwise.
        client.run_on_scheduler(setup_memory_pool, pool_size=1e9, **pool_opts)

    scheduler_workers = client.run_on_scheduler(get_scheduler_workers)
    n_workers = len(scheduler_workers)
    client.wait_for_workers(n_workers)

    # The "base" and "other" DataFrames may use different chunk counts;
    # both default to one chunk per worker.
    args.base_chunks = args.base_chunks or n_workers
    args.other_chunks = args.other_chunks or n_workers

    if args.all_to_all:
        all_to_all(client)

    # Every run but the last is unprofiled.
    took_list = [
        run(client, args, n_workers, write_profile=None)
        for _ in range(args.runs - 1)
    ]
    took_list.append(run(client, args, n_workers, write_profile=args.profile))

    # Group transfer samples by (receiving worker, sending worker), ignoring
    # transfers below args.ignore_size.
    transfer_logs = client.run(
        lambda dask_worker: dask_worker.incoming_transfer_log
    )
    raw_rates = defaultdict(list)
    raw_sizes = defaultdict(list)
    for receiver, entries in transfer_logs.items():
        for entry in entries:
            if entry["total"] >= args.ignore_size:
                pair = (receiver, entry["who"])
                raw_rates[pair].append(entry["bandwidth"])
                raw_sizes[pair].append(entry["total"])

    bandwidths = {}
    for (receiver, sender), rates in raw_rates.items():
        pair_name = (
            scheduler_workers[receiver].name,
            scheduler_workers[sender].name,
        )
        bandwidths[pair_name] = [
            "%s/s" % format_bytes(q)
            for q in numpy.quantile(rates, [0.25, 0.50, 0.75])
        ]
    total_nbytes = {}
    for (receiver, sender), sizes in raw_sizes.items():
        pair_name = (
            scheduler_workers[receiver].name,
            scheduler_workers[sender].name,
        )
        total_nbytes[pair_name] = format_bytes(sum(sizes))

    if args.shuffle_join:
        broadcast = False
    elif args.broadcast_join:
        broadcast = True
    else:
        broadcast = "default"

    t_runs = numpy.empty(len(took_list))
    if args.markdown:
        print("```")
    print("Merge benchmark")
    print("-------------------------------")
    print(f"backend        | {args.backend}")
    print(f"merge type     | {args.type}")
    print(f"rows-per-chunk | {args.chunk_size}")
    print(f"base-chunks    | {args.base_chunks}")
    print(f"other-chunks   | {args.other_chunks}")
    print(f"broadcast      | {broadcast}")
    print(f"protocol       | {args.protocol}")
    print(f"device(s)      | {args.devs}")
    print(f"rmm-pool       | {not args.disable_rmm_pool}")
    print(f"frac-match     | {args.frac_match}")
    if args.protocol == "ucx":
        print(f"tcp            | {args.enable_tcp_over_ucx}")
        print(f"ib             | {args.enable_infiniband}")
        print(f"nvlink         | {args.enable_nvlink}")
    print(f"data-processed | {format_bytes(took_list[0][0])}")
    print("===============================")
    print("Wall-clock     | Throughput")
    print("-------------------------------")
    for idx, (data_processed, took) in enumerate(took_list):
        rate = format_bytes(int(data_processed / took))
        print(f"{format_time(took).ljust(15)}| {rate}/s")
        # NOTE(review): only the numeric part of the formatted string is kept
        # (its unit is dropped), so runs rendered in different units are not
        # comparable in the plot.
        t_runs[idx] = float(rate.split(" ")[0])
    print("===============================")
    if args.markdown:
        print("\n```")

    if args.plot is not None:
        plot_benchmark(t_runs, args.plot, historical=True)

    if args.backend == "dask":
        if args.markdown:
            print(
                "<details>\n<summary>Worker-Worker Transfer Rates</summary>\n\n```"
            )
        print("(w1,w2)     | 25% 50% 75% (total nbytes)")
        print("-------------------------------")
        # Worker names are printed with %s (not %02d) when they may not be
        # integers (multi-node or external scheduler).
        if args.multi_node or args.sched_addr:
            row_fmt = "(%s,%s)     | %s %s %s (%s)"
        else:
            row_fmt = "(%02d,%02d)     | %s %s %s (%s)"
        for (d1, d2), (q25, q50, q75) in sorted(bandwidths.items()):
            print(row_fmt % (d1, d2, q25, q50, q75, total_nbytes[(d1, d2)]))
        if args.markdown:
            print("```\n</details>\n")

    if args.multi_node:
        client.shutdown()
        client.close()
示例#29
0
def set_chunk_size(num_bytes: Union[int, float] = 128e6):
    """Set dask's ``array.chunk_size`` config to ``format_bytes(num_bytes)``."""
    chunk_size = format_bytes(num_bytes)
    dask.config.set({'array.chunk_size': chunk_size})
示例#30
0
async def run(args):
    """Run the roundtrip benchmark asynchronously and print a summary.

    Starts the cluster selected by ``get_cluster_options`` in asynchronous
    mode, sets up RMM pools on the workers and the scheduler, calls ``_run``
    ``args.runs`` times, then prints per-run wall-clock times and
    worker-to-worker transfer-rate quantiles.

    NOTE(review): ``size`` and ``chunksize`` are only bound inside the run
    loop, so ``args.runs < 1`` would raise ``UnboundLocalError`` when printing.
    """
    cluster_options = get_cluster_options(args)
    Cluster = cluster_options["class"]
    cluster_args = cluster_options["args"]
    cluster_kwargs = cluster_options["kwargs"]
    scheduler_addr = cluster_options["scheduler_addr"]

    filterwarnings("ignore",
                   message=".*NVLink.*rmm_pool_size.*",
                   category=UserWarning)

    async with Cluster(*cluster_args, **cluster_kwargs,
                       asynchronous=True) as cluster:
        if args.multi_node:
            import asyncio

            # Allow some time for workers to start and connect to scheduler.
            # This coroutine runs on the event loop that drives the
            # asynchronous cluster and client, so a blocking time.sleep()
            # would stall it; asyncio.sleep() yields instead.
            # TODO: make this a command-line argument?
            await asyncio.sleep(15)

        # Use the scheduler address with an SSHCluster rather than the cluster
        # object, otherwise we can't shut it down.
        async with Client(scheduler_addr if args.multi_node else cluster,
                          asynchronous=True) as client:
            scheduler_workers = await client.run_on_scheduler(
                get_scheduler_workers)

            await client.run(setup_memory_pool,
                             disable_pool=args.disable_rmm_pool)
            # Create an RMM pool on the scheduler due to occasional deserialization
            # of CUDA objects. May cause issues with InfiniBand otherwise.
            await client.run_on_scheduler(setup_memory_pool,
                                          1e9,
                                          disable_pool=args.disable_rmm_pool)

            took_list = []
            for i in range(args.runs):
                res = await _run(client, args)
                took_list.append((res["took"], res["npartitions"]))
                size = res["shape"]
                chunksize = res["chunksize"]

            # Collect, aggregate, and print peer-to-peer bandwidths
            incoming_logs = await client.run(
                lambda dask_worker: dask_worker.incoming_transfer_log)
            bandwidths = defaultdict(list)
            total_nbytes = defaultdict(list)
            for k, L in incoming_logs.items():
                for d in L:
                    if d["total"] >= args.ignore_size:
                        bandwidths[k, d["who"]].append(d["bandwidth"])
                        total_nbytes[k, d["who"]].append(d["total"])

            bandwidths = {(
                scheduler_workers[w1].name,
                scheduler_workers[w2].name,
            ): [
                "%s/s" % format_bytes(x)
                for x in np.quantile(v, [0.25, 0.50, 0.75])
            ]
                          for (w1, w2), v in bandwidths.items()}
            total_nbytes = {(
                scheduler_workers[w1].name,
                scheduler_workers[w2].name,
            ): format_bytes(sum(nb))
                            for (w1, w2), nb in total_nbytes.items()}

            print("Roundtrip benchmark")
            print("--------------------------")
            print(f"Operation          | {args.operation}")
            print(f"User size          | {args.size}")
            print(f"User second size   | {args.second_size}")
            print(f"User chunk-size    | {args.chunk_size}")
            print(f"Compute shape      | {size}")
            print(f"Compute chunk-size | {chunksize}")
            print(f"Ignore-size        | {format_bytes(args.ignore_size)}")
            print(f"Protocol           | {args.protocol}")
            print(f"Device(s)          | {args.devs}")
            print(f"Worker Thread(s)   | {args.threads_per_worker}")
            print("==========================")
            print("Wall-clock         | npartitions")
            print("--------------------------")
            for (took, npartitions) in took_list:
                t = format_time(took)
                t += " " * (11 - len(t))
                print(f"{t}        | {npartitions}")
            print("==========================")
            print("(w1,w2)            | 25% 50% 75% (total nbytes)")
            print("--------------------------")
            for (d1, d2), bw in sorted(bandwidths.items()):
                fmt = ("(%s,%s)            | %s %s %s (%s)"
                       if args.multi_node or args.sched_addr else
                       "(%02d,%02d)            | %s %s %s (%s)")
                print(fmt %
                      (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))

            # An SSHCluster will not automatically shut down, we have to
            # ensure it does.
            if args.multi_node:
                await client.shutdown()