Example #1
from dask.utils import format_time


def test_format_time():
    assert format_time(1.4) == "1.40 s"
    assert format_time(10.4) == "10.40 s"
    assert format_time(100.4) == "100.40 s"
    assert format_time(1000.4) == "16m 40s"
    assert format_time(10000.4) == "2hr 46m"
    assert format_time(1234.567) == "20m 34s"
    assert format_time(12345.67) == "3hr 25m"
    assert format_time(123456.78) == "34hr 17m"
    assert format_time(1234567.8) == "14d 6hr"
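# The unit thresholds exercised above can be reproduced with the small sketch below.
# This is a reconstruction from the assertions alone, not necessarily how
# dask.utils.format_time is actually implemented.
def format_time_sketch(n: float) -> str:
    """Format a duration in seconds, switching to coarser units as it grows."""
    if n > 2 * 24 * 60 * 60:  # beyond ~2 days: days + hours
        d, rest = divmod(int(n), 24 * 60 * 60)
        return f"{d}d {rest // 3600}hr"
    if n > 2 * 60 * 60:  # beyond ~2 hours: hours + minutes
        h, rest = divmod(int(n), 60 * 60)
        return f"{h}hr {rest // 60}m"
    if n > 10 * 60:  # beyond 10 minutes: minutes + seconds
        m, s = divmod(int(n), 60)
        return f"{m}m {s}s"
    return "%.2f s" % n  # otherwise: seconds with two decimals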
Example #2
    def traverse(state, start, stop, height):
        if not state["count"]:
            return
        starts.append(start)
        stops.append(stop)
        heights.append(height)
        width = stop - start
        widths.append(width)
        states.append(state)
        times.append(format_time(state["count"] * profile_interval))

        desc = state["description"]
        filenames.append(desc["filename"])
        lines.append(desc["line"])
        line_numbers.append(desc["line_number"])
        names.append(desc["name"])

        try:
            fn = desc["filename"]
        except KeyError:  # pragma: no cover - dict lookup raises KeyError, not IndexError
            colors.append("gray")
        else:
            if fn == "<low-level>":  # pragma: no cover
                colors.append("lightgray")
            else:
                colors.append(color_of(fn))

        delta = (stop - start) / state["count"]

        x = start

        for _, child in state["children"].items():
            width = child["count"] * delta
            traverse(child, x, x + width, height + 1)
            x += width
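# A hypothetical profile node illustrating the shape traverse() reads; the field
# names come from the lookups above, while the values are made up for illustration.
example_state = {
    "count": 10,  # number of samples attributed to this frame
    "description": {
        "filename": "example.py",
        "line": "x = fib(n - 1) + fib(n - 2)",
        "line_number": 3,
        "name": "fib",
    },
    "children": {},  # child frames keyed by identifier, same structure recursively
}
# traverse(example_state, 0, 10, 0) appends one rectangle for this node and then
# splits its width among the children in proportion to their sample counts.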
Example #3
def rectangles(msgs, workers=None, start_boundary=0):
    if workers is None:
        workers = {}

    L_start = []
    L_duration = []
    L_duration_text = []
    L_key = []
    L_name = []
    L_color = []
    L_alpha = []
    L_worker = []
    L_worker_thread = []
    L_y = []

    for msg in msgs:
        key = msg["key"]
        name = key_split(key)
        startstops = msg.get("startstops", [])
        try:
            worker_thread = "%s-%d" % (msg["worker"], msg["thread"])
        except Exception:
            # Skip messages missing worker/thread info, but log them for debugging.
            logger.warning("Message contained bad information: %s", msg, exc_info=True)
            continue

        if worker_thread not in workers:
            workers[worker_thread] = len(workers) / 2

        for startstop in startstops:
            if startstop["start"] < start_boundary:
                continue
            color = colors[startstop["action"]]
            if type(color) is not str:
                color = color(msg)

            L_start.append((startstop["start"] + startstop["stop"]) / 2 * 1000)
            L_duration.append(1000 * (startstop["stop"] - startstop["start"]))
            L_duration_text.append(format_time(startstop["stop"] - startstop["start"]))
            L_key.append(key)
            L_name.append(prefix[startstop["action"]] + name)
            L_color.append(color)
            L_alpha.append(alphas[startstop["action"]])
            L_worker.append(msg["worker"])
            L_worker_thread.append(worker_thread)
            L_y.append(workers[worker_thread])

    return {
        "start": L_start,
        "duration": L_duration,
        "duration_text": L_duration_text,
        "key": L_key,
        "name": L_name,
        "color": L_color,
        "alpha": L_alpha,
        "worker": L_worker,
        "worker_thread": L_worker_thread,
        "y": L_y,
    }
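# A hypothetical task-stream message showing the fields rectangles() reads;
# the values are illustrative, not real scheduler output.
example_msg = {
    "key": "('sum-aggregate', 0)",
    "worker": "tcp://127.0.0.1:36789",
    "thread": 140214861,
    "startstops": [
        {"action": "compute", "start": 10.0, "stop": 10.5},
    ],
}
# rectangles([example_msg]) would produce one bar: its x-center is the midpoint of
# start/stop in milliseconds, its width the duration in milliseconds, and its y
# position one lane per distinct worker-thread pair.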
Example #4
    def _draw_bar(self, frac, elapsed):
        from dask.utils import format_time

        bar = "#" * int(self._width * frac)
        percent = int(100 * frac)
        elapsed = format_time(elapsed)
        msg = "\r[{0:<{1}}] | {2}% Completed | {3}".format(
            bar, self._width, percent, elapsed
        )
        with contextlib.suppress(ValueError):
            if self._file is not None:
                self._file.write(msg)
                self._file.flush()
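# A minimal sketch of the message _draw_bar builds, using assumed values
# (width 40, frac 0.5, elapsed 12.3 s); format_time comes from dask.utils as above.
from dask.utils import format_time

width, frac, elapsed = 40, 0.5, 12.3
bar = "#" * int(width * frac)
msg = "\r[{0:<{1}}] | {2}% Completed | {3}".format(
    bar, width, int(100 * frac), format_time(elapsed)
)
# msg renders as "[####################                    ] | 50% Completed | 12.30 s"
# (20 '#' characters padded to the full 40-character bar width)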
Example #5
    def update(self):
        with log_errors():
            outgoing = self.worker.outgoing_transfer_log
            n = self.worker.outgoing_count - self.last_outgoing
            outgoing = [outgoing[-i].copy() for i in range(1, n + 1)]
            self.last_outgoing = self.worker.outgoing_count

            incoming = self.worker.incoming_transfer_log
            n = self.worker.incoming_count - self.last_incoming
            incoming = [incoming[-i].copy() for i in range(1, n + 1)]
            self.last_incoming = self.worker.incoming_count

            for msgs, source in [
                (incoming, self.incoming),
                (outgoing, self.outgoing),
            ]:
                for msg in msgs:
                    if "compressed" in msg:
                        del msg["compressed"]
                    del msg["keys"]

                    bandwidth = msg["total"] / (msg["duration"] or 0.5)
                    bw = max(min(bandwidth / 500e6, 1), 0.3)
                    msg["alpha"] = bw
                    try:
                        msg["y"] = self.who[msg["who"]]
                    except KeyError:
                        self.who[msg["who"]] = len(self.who)
                        msg["y"] = self.who[msg["who"]]

                    msg["hover"] = "{} / {} = {}/s".format(
                        format_bytes(msg["total"]),
                        format_time(msg["duration"]),
                        format_bytes(msg["total"] / msg["duration"]),
                    )

                    for k in ["middle", "duration", "start", "stop"]:
                        msg[k] = msg[k] * 1000

                if msgs:
                    msgs = transpose(msgs)
                    # If the new transfers start well after (>10 s, times are in ms
                    # here) the last plotted one, replace the data outright;
                    # otherwise append via streaming with a rollover.
                    if (
                        len(source.data["stop"])
                        and min(msgs["start"]) > source.data["stop"][-1] + 10000
                    ):
                        source.data.update(msgs)
                    else:
                        source.stream(msgs, rollover=10000)
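# The alpha channel above encodes transfer bandwidth: it saturates at 500 MB/s
# and is floored at 0.3. Illustrative values:
for bandwidth in (50e6, 250e6, 1e9):
    print(bandwidth, max(min(bandwidth / 500e6, 1), 0.3))
# 50 MB/s -> 0.3 (floor), 250 MB/s -> 0.5, 1 GB/s -> 1.0 (cap)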
def main(args):
    cluster_options = get_cluster_options(args)
    Cluster = cluster_options["class"]
    cluster_args = cluster_options["args"]
    cluster_kwargs = cluster_options["kwargs"]
    scheduler_addr = cluster_options["scheduler_addr"]

    if args.sched_addr:
        client = Client(args.sched_addr)
    else:
        filterwarnings("ignore",
                       message=".*NVLink.*rmm_pool_size.*",
                       category=UserWarning)

        cluster = Cluster(*cluster_args, **cluster_kwargs)
        if args.multi_node:
            import time

            # Allow some time for workers to start and connect to scheduler
            # TODO: make this a command-line argument?
            time.sleep(15)

        client = Client(scheduler_addr if args.multi_node else cluster)

    if args.type == "gpu":
        client.run(
            setup_memory_pool,
            pool_size=args.rmm_pool_size,
            disable_pool=args.disable_rmm_pool,
            log_directory=args.rmm_log_directory,
        )
        # Create an RMM pool on the scheduler due to occasional deserialization
        # of CUDA objects. May cause issues with InfiniBand otherwise.
        client.run_on_scheduler(
            setup_memory_pool,
            pool_size=1e9,
            disable_pool=args.disable_rmm_pool,
            log_directory=args.rmm_log_directory,
        )

    scheduler_workers = client.run_on_scheduler(get_scheduler_workers)
    n_workers = len(scheduler_workers)
    client.wait_for_workers(n_workers)

    if args.all_to_all:
        all_to_all(client)

    took_list = []
    for _ in range(args.runs - 1):
        took_list.append(run(client, args, n_workers, write_profile=None))
    took_list.append(
        run(client, args, n_workers,
            write_profile=args.profile))  # Only profiling the last run

    # Collect, aggregate, and print peer-to-peer bandwidths
    incoming_logs = client.run(
        lambda dask_worker: dask_worker.incoming_transfer_log)
    bandwidths = defaultdict(list)
    total_nbytes = defaultdict(list)
    for k, L in incoming_logs.items():
        for d in L:
            if d["total"] >= args.ignore_size:
                bandwidths[k, d["who"]].append(d["bandwidth"])
                total_nbytes[k, d["who"]].append(d["total"])
    bandwidths = {
        (scheduler_workers[w1].name, scheduler_workers[w2].name): [
            "%s/s" % format_bytes(x)
            for x in numpy.quantile(v, [0.25, 0.50, 0.75])
        ]
        for (w1, w2), v in bandwidths.items()
    }
    total_nbytes = {
        (scheduler_workers[w1].name, scheduler_workers[w2].name): format_bytes(sum(nb))
        for (w1, w2), nb in total_nbytes.items()
    }

    t_runs = numpy.empty(len(took_list))
    if args.markdown:
        print("```")
    print("Shuffle benchmark")
    print("-------------------------------")
    print(f"backend        | {args.backend}")
    print(f"partition-size | {format_bytes(args.partition_size)}")
    print(f"in-parts       | {args.in_parts}")
    print(f"protocol       | {args.protocol}")
    print(f"device(s)      | {args.devs}")
    if args.device_memory_limit:
        print(f"memory-limit   | {format_bytes(args.device_memory_limit)}")
    print(f"rmm-pool       | {(not args.disable_rmm_pool)}")
    if args.protocol == "ucx":
        print(f"tcp            | {args.enable_tcp_over_ucx}")
        print(f"ib             | {args.enable_infiniband}")
        print(f"nvlink         | {args.enable_nvlink}")
    print(f"data-processed | {format_bytes(took_list[0][0])}")
    print("===============================")
    print("Wall-clock     | Throughput")
    print("-------------------------------")
    for idx, (data_processed, took) in enumerate(took_list):
        throughput = int(data_processed / took)
        m = format_time(took)
        m += " " * (15 - len(m))
        print(f"{m}| {format_bytes(throughput)}/s")
        t_runs[idx] = float(format_bytes(throughput).split(" ")[0])
    print("===============================")
    if args.markdown:
        print("\n```")

    if args.plot is not None:
        plot_benchmark(t_runs, args.plot, historical=True)

    if args.backend == "dask":
        if args.markdown:
            print(
                "<details>\n<summary>Worker-Worker Transfer Rates</summary>\n\n```"
            )
        print("(w1,w2)        | 25% 50% 75% (total nbytes)")
        print("-------------------------------")
        for (d1, d2), bw in sorted(bandwidths.items()):
            fmt = ("(%s,%s)        | %s %s %s (%s)" if args.multi_node or
                   args.sched_addr else "(%02d,%02d)        | %s %s %s (%s)")
            print(fmt % (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))
        if args.markdown:
            print("```\n</details>\n")

    if args.benchmark_json:
        bandwidths_json = {
            "bandwidth_({d1},{d2})_{i}" if args.multi_node or args.sched_addr
            else "(%02d,%02d)_%s" % (d1, d2, i): parse_bytes(v.rstrip("/s"))
            for (d1, d2), bw in sorted(bandwidths.items()) for i, v in zip(
                ["25%", "50%", "75%", "total_nbytes"],
                [bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]],
            )
        }

        with open(args.benchmark_json, "a") as fp:
            for data_processed, took in took_list:
                fp.write(
                    dumps(
                        dict(
                            {
                                "backend": args.backend,
                                "partition_size": args.partition_size,
                                "in_parts": args.in_parts,
                                "protocol": args.protocol,
                                "devs": args.devs,
                                "device_memory_limit":
                                args.device_memory_limit,
                                "rmm_pool": not args.disable_rmm_pool,
                                "tcp": args.enable_tcp_over_ucx,
                                "ib": args.enable_infiniband,
                                "nvlink": args.enable_nvlink,
                                "data_processed": data_processed,
                                "wall_clock": took,
                                "throughput": data_processed / took,
                            },
                            **bandwidths_json,
                        )) + "\n")

    if args.multi_node:
        client.shutdown()
        client.close()
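# Results are appended to args.benchmark_json as one JSON object per line
# (JSON Lines). A minimal reading sketch, assuming `dumps` above is json.dumps:
import json

def load_benchmark_results(path):
    """Return the list of per-run result dicts from a JSON Lines file."""
    with open(path) as fp:
        return [json.loads(line) for line in fp if line.strip()]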
Example #7
async def run(args):
    cluster_options = get_cluster_options(args)
    Cluster = cluster_options["class"]
    cluster_args = cluster_options["args"]
    cluster_kwargs = cluster_options["kwargs"]
    scheduler_addr = cluster_options["scheduler_addr"]

    filterwarnings("ignore",
                   message=".*NVLink.*rmm_pool_size.*",
                   category=UserWarning)

    async with Cluster(*cluster_args, **cluster_kwargs,
                       asynchronous=True) as cluster:
        if args.multi_node:
            import asyncio

            # Allow some time for workers to start and connect to scheduler,
            # without blocking the running event loop
            # TODO: make this a command-line argument?
            await asyncio.sleep(15)

        # Use the scheduler address with an SSHCluster rather than the cluster
        # object, otherwise we can't shut it down.
        async with Client(scheduler_addr if args.multi_node else cluster,
                          asynchronous=True) as client:
            scheduler_workers = await client.run_on_scheduler(
                get_scheduler_workers)

            await client.run(setup_memory_pool,
                             disable_pool=args.disable_rmm_pool)
            # Create an RMM pool on the scheduler due to occasional deserialization
            # of CUDA objects. May cause issues with InfiniBand otherwise.
            await client.run_on_scheduler(setup_memory_pool,
                                          1e9,
                                          disable_pool=args.disable_rmm_pool)

            took_list = []
            for i in range(args.runs):
                res = await _run(client, args)
                took_list.append((res["took"], res["npartitions"]))
                size = res["shape"]
                chunksize = res["chunksize"]

            # Collect, aggregate, and print peer-to-peer bandwidths
            incoming_logs = await client.run(
                lambda dask_worker: dask_worker.incoming_transfer_log)
            bandwidths = defaultdict(list)
            total_nbytes = defaultdict(list)
            for k, L in incoming_logs.items():
                for d in L:
                    if d["total"] >= args.ignore_size:
                        bandwidths[k, d["who"]].append(d["bandwidth"])
                        total_nbytes[k, d["who"]].append(d["total"])

            bandwidths = {
                (scheduler_workers[w1].name, scheduler_workers[w2].name): [
                    "%s/s" % format_bytes(x)
                    for x in np.quantile(v, [0.25, 0.50, 0.75])
                ]
                for (w1, w2), v in bandwidths.items()
            }
            total_nbytes = {
                (scheduler_workers[w1].name, scheduler_workers[w2].name): format_bytes(
                    sum(nb)
                )
                for (w1, w2), nb in total_nbytes.items()
            }

            print("Roundtrip benchmark")
            print("--------------------------")
            print(f"Operation          | {args.operation}")
            print(f"User size          | {args.size}")
            print(f"User second size   | {args.second_size}")
            print(f"User chunk-size    | {args.chunk_size}")
            print(f"Compute shape      | {size}")
            print(f"Compute chunk-size | {chunksize}")
            print(f"Ignore-size        | {format_bytes(args.ignore_size)}")
            print(f"Protocol           | {args.protocol}")
            print(f"Device(s)          | {args.devs}")
            print(f"Worker Thread(s)   | {args.threads_per_worker}")
            print("==========================")
            print("Wall-clock         | npartitions")
            print("--------------------------")
            for (took, npartitions) in took_list:
                t = format_time(took)
                t += " " * (11 - len(t))
                print(f"{t}        | {npartitions}")
            print("==========================")
            print("(w1,w2)            | 25% 50% 75% (total nbytes)")
            print("--------------------------")
            for (d1, d2), bw in sorted(bandwidths.items()):
                fmt = ("(%s,%s)            | %s %s %s (%s)"
                       if args.multi_node or args.sched_addr else
                       "(%02d,%02d)            | %s %s %s (%s)")
                print(fmt %
                      (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))

            # An SSHCluster will not automatically shut down, we have to
            # ensure it does.
            if args.multi_node:
                await client.shutdown()
def main(args):
    cluster_options = get_cluster_options(args)
    Cluster = cluster_options["class"]
    cluster_args = cluster_options["args"]
    cluster_kwargs = cluster_options["kwargs"]
    scheduler_addr = cluster_options["scheduler_addr"]

    if args.sched_addr:
        client = Client(args.sched_addr)
    else:
        filterwarnings("ignore",
                       message=".*NVLink.*rmm_pool_size.*",
                       category=UserWarning)

        cluster = Cluster(*cluster_args, **cluster_kwargs)
        if args.multi_node:
            import time

            # Allow some time for workers to start and connect to scheduler
            # TODO: make this a command-line argument?
            time.sleep(15)

        client = Client(scheduler_addr if args.multi_node else cluster)

    if args.type == "gpu":
        client.run(
            setup_memory_pool,
            pool_size=args.rmm_pool_size,
            disable_pool=args.disable_rmm_pool,
            log_directory=args.rmm_log_directory,
        )
        # Create an RMM pool on the scheduler due to occasional deserialization
        # of CUDA objects. May cause issues with InfiniBand otherwise.
        client.run_on_scheduler(
            setup_memory_pool,
            pool_size=1e9,
            disable_pool=args.disable_rmm_pool,
            log_directory=args.rmm_log_directory,
        )

    scheduler_workers = client.run_on_scheduler(get_scheduler_workers)
    n_workers = len(scheduler_workers)
    client.wait_for_workers(n_workers)

    # Allow the number of chunks to vary between
    # the "base" and "other" DataFrames
    args.base_chunks = args.base_chunks or n_workers
    args.other_chunks = args.other_chunks or n_workers

    if args.all_to_all:
        all_to_all(client)

    took_list = []
    for _ in range(args.runs - 1):
        took_list.append(run(client, args, n_workers, write_profile=None))
    took_list.append(
        run(client, args, n_workers,
            write_profile=args.profile))  # Only profiling the last run

    # Collect, aggregate, and print peer-to-peer bandwidths
    incoming_logs = client.run(
        lambda dask_worker: dask_worker.incoming_transfer_log)
    bandwidths = defaultdict(list)
    total_nbytes = defaultdict(list)
    for k, L in incoming_logs.items():
        for d in L:
            if d["total"] >= args.ignore_size:
                bandwidths[k, d["who"]].append(d["bandwidth"])
                total_nbytes[k, d["who"]].append(d["total"])
    bandwidths = {
        (scheduler_workers[w1].name, scheduler_workers[w2].name): [
            "%s/s" % format_bytes(x)
            for x in numpy.quantile(v, [0.25, 0.50, 0.75])
        ]
        for (w1, w2), v in bandwidths.items()
    }
    total_nbytes = {
        (scheduler_workers[w1].name, scheduler_workers[w2].name): format_bytes(sum(nb))
        for (w1, w2), nb in total_nbytes.items()
    }

    broadcast = (False if args.shuffle_join else
                 (True if args.broadcast_join else "default"))

    t_runs = numpy.empty(len(took_list))
    if args.markdown:
        print("```")
    print("Merge benchmark")
    print("-------------------------------")
    print(f"backend        | {args.backend}")
    print(f"merge type     | {args.type}")
    print(f"rows-per-chunk | {args.chunk_size}")
    print(f"base-chunks    | {args.base_chunks}")
    print(f"other-chunks   | {args.other_chunks}")
    print(f"broadcast      | {broadcast}")
    print(f"protocol       | {args.protocol}")
    print(f"device(s)      | {args.devs}")
    print(f"rmm-pool       | {(not args.disable_rmm_pool)}")
    print(f"frac-match     | {args.frac_match}")
    if args.protocol == "ucx":
        print(f"tcp            | {args.enable_tcp_over_ucx}")
        print(f"ib             | {args.enable_infiniband}")
        print(f"nvlink         | {args.enable_nvlink}")
    print(f"data-processed | {format_bytes(took_list[0][0])}")
    print("===============================")
    print("Wall-clock     | Throughput")
    print("-------------------------------")
    for idx, (data_processed, took) in enumerate(took_list):
        throughput = int(data_processed / took)
        m = format_time(took)
        m += " " * (15 - len(m))
        print(f"{m}| {format_bytes(throughput)}/s")
        t_runs[idx] = float(format_bytes(throughput).split(" ")[0])
    print("===============================")
    if args.markdown:
        print("\n```")

    if args.plot is not None:
        plot_benchmark(t_runs, args.plot, historical=True)

    if args.backend == "dask":
        if args.markdown:
            print(
                "<details>\n<summary>Worker-Worker Transfer Rates</summary>\n\n```"
            )
        print("(w1,w2)     | 25% 50% 75% (total nbytes)")
        print("-------------------------------")
        for (d1, d2), bw in sorted(bandwidths.items()):
            fmt = ("(%s,%s)     | %s %s %s (%s)" if args.multi_node
                   or args.sched_addr else "(%02d,%02d)     | %s %s %s (%s)")
            print(fmt % (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))
        if args.markdown:
            print("```\n</details>\n")

    if args.multi_node:
        client.shutdown()
        client.close()
Example #9
def main(args):
    # Set up workers on the local machine
    if args.protocol == "tcp":
        cluster = LocalCUDACluster(
            protocol=args.protocol,
            n_workers=args.n_workers,
            CUDA_VISIBLE_DEVICES=args.devs,
        )
    else:
        enable_infiniband = args.enable_infiniband
        enable_nvlink = args.enable_nvlink
        enable_tcp_over_ucx = args.enable_tcp_over_ucx
        cluster = LocalCUDACluster(
            protocol=args.protocol,
            n_workers=args.n_workers,
            CUDA_VISIBLE_DEVICES=args.devs,
            ucx_net_devices="auto",
            enable_tcp_over_ucx=enable_tcp_over_ucx,
            enable_infiniband=enable_infiniband,
            enable_nvlink=enable_nvlink,
        )
        initialize(
            create_cuda_context=True,
            enable_tcp_over_ucx=enable_tcp_over_ucx,
            enable_infiniband=enable_infiniband,
            enable_nvlink=enable_nvlink,
        )
    client = Client(cluster)

    def _worker_setup(initial_pool_size=None):
        import rmm

        rmm.reinitialize(
            pool_allocator=not args.no_rmm_pool,
            devices=0,
            initial_pool_size=initial_pool_size,
        )
        cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)

    client.run(_worker_setup)
    # Create an RMM pool on the scheduler due to occasional deserialization
    # of CUDA objects. May cause issues with InfiniBand otherwise.
    client.run_on_scheduler(_worker_setup, 1e9)

    took_list = []
    for _ in range(args.runs - 1):
        took_list.append(run(client, args, write_profile=None))
    took_list.append(
        run(client, args, write_profile=args.profile)
    )  # Only profiling the last run

    # Collect, aggregate, and print peer-to-peer bandwidths
    incoming_logs = client.run(lambda dask_worker: dask_worker.incoming_transfer_log)
    bandwidths = defaultdict(list)
    total_nbytes = defaultdict(list)
    for k, L in incoming_logs.items():
        for d in L:
            if d["total"] >= args.ignore_size:
                bandwidths[k, d["who"]].append(d["bandwidth"])
                total_nbytes[k, d["who"]].append(d["total"])
    bandwidths = {
        (cluster.scheduler.workers[w1].name, cluster.scheduler.workers[w2].name): [
            "%s/s" % format_bytes(x) for x in numpy.quantile(v, [0.25, 0.50, 0.75])
        ]
        for (w1, w2), v in bandwidths.items()
    }
    total_nbytes = {
        (
            cluster.scheduler.workers[w1].name,
            cluster.scheduler.workers[w2].name,
        ): format_bytes(sum(nb))
        for (w1, w2), nb in total_nbytes.items()
    }

    if args.markdown:
        print("```")
    print("Merge benchmark")
    print("-------------------------------")
    print(f"backend        | {args.backend}")
    print(f"rows-per-chunk | {args.chunk_size}")
    print(f"protocol       | {args.protocol}")
    print(f"device(s)      | {args.devs}")
    print(f"rmm-pool       | {(not args.no_rmm_pool)}")
    print(f"frac-match     | {args.frac_match}")
    if args.protocol == "ucx":
        print(f"tcp            | {args.enable_tcp_over_ucx}")
        print(f"ib             | {args.enable_infiniband}")
        print(f"nvlink         | {args.enable_nvlink}")
    print(f"data-processed | {format_bytes(took_list[0][0])}")
    print("===============================")
    print("Wall-clock     | Throughput")
    print("-------------------------------")
    for data_processed, took in took_list:
        throughput = int(data_processed / took)
        m = format_time(took)
        m += " " * (15 - len(m))
        print(f"{m}| {format_bytes(throughput)}/s")
    print("===============================")
    if args.markdown:
        print("\n```")

    if args.backend == "dask":
        if args.markdown:
            print("<details>\n<summary>Worker-Worker Transfer Rates</summary>\n\n```")
        print("(w1,w2)     | 25% 50% 75% (total nbytes)")
        print("-------------------------------")
        for (d1, d2), bw in sorted(bandwidths.items()):
            print(
                "(%02d,%02d)     | %s %s %s (%s)"
                % (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)])
            )
        if args.markdown:
            print("```\n</details>\n")
def all_ping(client: Client):
    workers = list(client.scheduler_info()["workers"])
    start = time.time()
    client.run(ping, workers)
    stop = time.time()
    print(format_time(stop - start))
async def run(args):
    cluster_options = get_cluster_options(args)
    Cluster = cluster_options["class"]
    cluster_args = cluster_options["args"]
    cluster_kwargs = cluster_options["kwargs"]
    scheduler_addr = cluster_options["scheduler_addr"]

    filterwarnings("ignore",
                   message=".*NVLink.*rmm_pool_size.*",
                   category=UserWarning)

    async with Cluster(*cluster_args, **cluster_kwargs,
                       asynchronous=True) as cluster:
        if args.multi_node:
            import asyncio

            # Allow some time for workers to start and connect to scheduler,
            # without blocking the running event loop
            # TODO: make this a command-line argument?
            await asyncio.sleep(15)

        # Use the scheduler address with an SSHCluster rather than the cluster
        # object, otherwise we can't shut it down.
        async with Client(scheduler_addr if args.multi_node else cluster,
                          asynchronous=True) as client:
            scheduler_workers = await client.run_on_scheduler(
                get_scheduler_workers)

            await client.run(
                setup_memory_pool,
                disable_pool=args.disable_rmm_pool,
                log_directory=args.rmm_log_directory,
            )
            # Create an RMM pool on the scheduler due to occasional deserialization
            # of CUDA objects. May cause issues with InfiniBand otherwise.
            await client.run_on_scheduler(
                setup_memory_pool,
                pool_size=1e9,
                disable_pool=args.disable_rmm_pool,
                log_directory=args.rmm_log_directory,
            )

            took_list = []
            for i in range(args.runs):
                took_list.append(await _run(client, args))

            # Collect, aggregate, and print peer-to-peer bandwidths
            incoming_logs = await client.run(
                lambda dask_worker: dask_worker.incoming_transfer_log)
            bandwidths = defaultdict(list)
            total_nbytes = defaultdict(list)
            for k, L in incoming_logs.items():
                for d in L:
                    if d["total"] >= args.ignore_size:
                        bandwidths[k, d["who"]].append(d["bandwidth"])
                        total_nbytes[k, d["who"]].append(d["total"])

            bandwidths = {
                (scheduler_workers[w1].name, scheduler_workers[w2].name): [
                    "%s/s" % format_bytes(x)
                    for x in np.quantile(v, [0.25, 0.50, 0.75])
                ]
                for (w1, w2), v in bandwidths.items()
            }
            total_nbytes = {
                (scheduler_workers[w1].name, scheduler_workers[w2].name): format_bytes(
                    sum(nb)
                )
                for (w1, w2), nb in total_nbytes.items()
            }

            print("Roundtrip benchmark")
            print("--------------------------")
            print(f"Size         | {args.size}*{args.size}")
            print(f"Chunk-size   | {args.chunk_size}")
            print(f"Ignore-size  | {format_bytes(args.ignore_size)}")
            print(f"Protocol     | {args.protocol}")
            print(f"Device(s)    | {args.devs}")
            if args.device_memory_limit:
                print(
                    f"memory-limit | {format_bytes(args.device_memory_limit)}")
            print("==========================")
            print("Wall-clock   | npartitions")
            print("--------------------------")
            for (took, npartitions) in took_list:
                t = format_time(took)
                t += " " * (12 - len(t))
                print(f"{t} | {npartitions}")
            print("==========================")
            print("(w1,w2)      | 25% 50% 75% (total nbytes)")
            print("--------------------------")
            for (d1, d2), bw in sorted(bandwidths.items()):
                fmt = ("(%s,%s)      | %s %s %s (%s)" if args.multi_node or
                       args.sched_addr else "(%02d,%02d)      | %s %s %s (%s)")
                print(fmt %
                      (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))

            if args.benchmark_json:
                bandwidths_json = {
                    "bandwidth_({d1},{d2})_{i}" if args.multi_node
                    or args.sched_addr else "(%02d,%02d)_%s" % (d1, d2, i):
                    parse_bytes(v.rstrip("/s"))
                    for (d1, d2), bw in sorted(bandwidths.items())
                    for i, v in zip(
                        ["25%", "50%", "75%", "total_nbytes"],
                        [bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]],
                    )
                }

                with open(args.benchmark_json, "a") as fp:
                    for took, npartitions in took_list:
                        fp.write(
                            dumps(
                                dict(
                                    {
                                        "size": args.size * args.size,
                                        "chunk_size": args.chunk_size,
                                        "ignore_size": args.ignore_size,
                                        "protocol": args.protocol,
                                        "devs": args.devs,
                                        "device_memory_limit":
                                        args.device_memory_limit,
                                        "worker_threads":
                                        args.threads_per_worker,
                                        "rmm_pool": not args.disable_rmm_pool,
                                        "tcp": args.enable_tcp_over_ucx,
                                        "ib": args.enable_infiniband,
                                        "nvlink": args.enable_nvlink,
                                        "wall_clock": took,
                                        "npartitions": npartitions,
                                    },
                                    **bandwidths_json,
                                )) + "\n")

            # An SSHCluster will not automatically shut down, we have to
            # ensure it does.
            if args.multi_node:
                await client.shutdown()
Example #12
def main(args):
    cluster_options = get_cluster_options(args)
    Cluster = cluster_options["class"]
    cluster_args = cluster_options["args"]
    cluster_kwargs = cluster_options["kwargs"]
    scheduler_addr = cluster_options["scheduler_addr"]

    cluster = Cluster(*cluster_args, **cluster_kwargs)
    if args.multi_node:
        import time

        # Allow some time for workers to start and connect to scheduler
        # TODO: make this a command-line argument?
        time.sleep(15)

    client = Client(scheduler_addr if args.multi_node else cluster)

    client.run(setup_memory_pool, disable_pool=args.no_rmm_pool)
    # Create an RMM pool on the scheduler due to occasional deserialization
    # of CUDA objects. May cause issues with InfiniBand otherwise.
    client.run_on_scheduler(setup_memory_pool, 1e9, disable_pool=args.no_rmm_pool)

    scheduler_workers = client.run_on_scheduler(get_scheduler_workers)
    n_workers = len(scheduler_workers)

    took_list = []
    for _ in range(args.runs - 1):
        took_list.append(run(client, args, n_workers, write_profile=None))
    took_list.append(
        run(client, args, n_workers, write_profile=args.profile)
    )  # Only profiling the last run

    # Collect, aggregate, and print peer-to-peer bandwidths
    incoming_logs = client.run(lambda dask_worker: dask_worker.incoming_transfer_log)
    bandwidths = defaultdict(list)
    total_nbytes = defaultdict(list)
    for k, L in incoming_logs.items():
        for d in L:
            if d["total"] >= args.ignore_size:
                bandwidths[k, d["who"]].append(d["bandwidth"])
                total_nbytes[k, d["who"]].append(d["total"])
    bandwidths = {
        (scheduler_workers[w1].name, scheduler_workers[w2].name): [
            "%s/s" % format_bytes(x) for x in numpy.quantile(v, [0.25, 0.50, 0.75])
        ]
        for (w1, w2), v in bandwidths.items()
    }
    total_nbytes = {
        (scheduler_workers[w1].name, scheduler_workers[w2].name,): format_bytes(sum(nb))
        for (w1, w2), nb in total_nbytes.items()
    }

    if args.markdown:
        print("```")
    print("Merge benchmark")
    print("-------------------------------")
    print(f"backend        | {args.backend}")
    print(f"rows-per-chunk | {args.chunk_size}")
    print(f"protocol       | {args.protocol}")
    print(f"device(s)      | {args.devs}")
    print(f"rmm-pool       | {(not args.no_rmm_pool)}")
    print(f"frac-match     | {args.frac_match}")
    if args.protocol == "ucx":
        print(f"tcp            | {args.enable_tcp_over_ucx}")
        print(f"ib             | {args.enable_infiniband}")
        print(f"nvlink         | {args.enable_nvlink}")
    print(f"data-processed | {format_bytes(took_list[0][0])}")
    print("===============================")
    print("Wall-clock     | Throughput")
    print("-------------------------------")
    for data_processed, took in took_list:
        throughput = int(data_processed / took)
        m = format_time(took)
        m += " " * (15 - len(m))
        print(f"{m}| {format_bytes(throughput)}/s")
    print("===============================")
    if args.markdown:
        print("\n```")

    if args.backend == "dask":
        if args.markdown:
            print("<details>\n<summary>Worker-Worker Transfer Rates</summary>\n\n```")
        print("(w1,w2)     | 25% 50% 75% (total nbytes)")
        print("-------------------------------")
        for (d1, d2), bw in sorted(bandwidths.items()):
            fmt = (
                "(%s,%s)     | %s %s %s (%s)"
                if args.multi_node
                else "(%02d,%02d)     | %s %s %s (%s)"
            )
            print(fmt % (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))
        if args.markdown:
            print("```\n</details>\n")

    if args.multi_node:
        client.shutdown()
        client.close()