Example #1
0
def main(args):
    """Run the shuffle benchmark and report timings and transfer rates.

    Connects to an existing scheduler (``args.sched_addr``) or creates a
    cluster from ``get_cluster_options(args)``, runs the benchmark
    ``args.runs`` times (always at least once), prints a summary table and,
    depending on the options, plots the results and appends one JSON record
    per run to ``args.benchmark_json``.

    Parameters
    ----------
    args
        Parsed command-line options. Attributes read here: ``sched_addr``,
        ``multi_node``, ``type``, ``rmm_pool_size``, ``disable_rmm_pool``,
        ``rmm_log_directory``, ``all_to_all``, ``runs``, ``profile``,
        ``ignore_size``, ``markdown``, ``backend``, ``partition_size``,
        ``in_parts``, ``protocol``, ``devs``, ``device_memory_limit``,
        ``enable_tcp_over_ucx``, ``enable_infiniband``, ``enable_nvlink``,
        ``plot`` and ``benchmark_json``.
    """
    cluster_options = get_cluster_options(args)
    Cluster = cluster_options["class"]
    cluster_args = cluster_options["args"]
    cluster_kwargs = cluster_options["kwargs"]
    scheduler_addr = cluster_options["scheduler_addr"]

    if args.sched_addr:
        client = Client(args.sched_addr)
    else:
        filterwarnings("ignore",
                       message=".*NVLink.*rmm_pool_size.*",
                       category=UserWarning)

        cluster = Cluster(*cluster_args, **cluster_kwargs)
        if args.multi_node:
            import time

            # Allow some time for workers to start and connect to scheduler
            # TODO: make this a command-line argument?
            time.sleep(15)

        client = Client(scheduler_addr if args.multi_node else cluster)

    if args.type == "gpu":
        client.run(
            setup_memory_pool,
            pool_size=args.rmm_pool_size,
            disable_pool=args.disable_rmm_pool,
            log_directory=args.rmm_log_directory,
        )
        # Create an RMM pool on the scheduler due to occasional deserialization
        # of CUDA objects. May cause issues with InfiniBand otherwise.
        client.run_on_scheduler(
            setup_memory_pool,
            pool_size=1e9,
            disable_pool=args.disable_rmm_pool,
            log_directory=args.rmm_log_directory,
        )

    scheduler_workers = client.run_on_scheduler(get_scheduler_workers)
    n_workers = len(scheduler_workers)
    client.wait_for_workers(n_workers)

    if args.all_to_all:
        all_to_all(client)

    # Each entry is unpacked below as (data_processed, took).
    took_list = []
    for _ in range(args.runs - 1):
        took_list.append(run(client, args, n_workers, write_profile=None))
    took_list.append(
        run(client, args, n_workers,
            write_profile=args.profile))  # Only profiling the last run

    # Collect, aggregate, and print peer-to-peer bandwidths
    incoming_logs = client.run(
        lambda dask_worker: dask_worker.incoming_transfer_log)
    bandwidths = defaultdict(list)
    total_nbytes = defaultdict(list)
    for k, L in incoming_logs.items():
        for d in L:
            # Transfers smaller than ``args.ignore_size`` are left out.
            if d["total"] >= args.ignore_size:
                bandwidths[k, d["who"]].append(d["bandwidth"])
                total_nbytes[k, d["who"]].append(d["total"])
    # Re-key both tables by worker name and turn the raw samples into
    # display strings (quartiles of bandwidth, sum of bytes).
    bandwidths = {(scheduler_workers[w1].name, scheduler_workers[w2].name): [
        "%s/s" % format_bytes(x)
        for x in numpy.quantile(v, [0.25, 0.50, 0.75])
    ]
                  for (w1, w2), v in bandwidths.items()}
    total_nbytes = {(
        scheduler_workers[w1].name,
        scheduler_workers[w2].name,
    ): format_bytes(sum(nb))
                    for (w1, w2), nb in total_nbytes.items()}

    # The "%02d" formats are only used when neither multi-node nor an
    # external scheduler is in play; otherwise worker names go through "%s".
    named_workers = args.multi_node or args.sched_addr

    t_runs = numpy.empty(len(took_list))
    if args.markdown:
        print("```")
    print("Shuffle benchmark")
    print("-------------------------------")
    print(f"backend        | {args.backend}")
    print(f"partition-size | {format_bytes(args.partition_size)}")
    print(f"in-parts       | {args.in_parts}")
    print(f"protocol       | {args.protocol}")
    print(f"device(s)      | {args.devs}")
    if args.device_memory_limit:
        print(f"memory-limit   | {format_bytes(args.device_memory_limit)}")
    print(f"rmm-pool       | {(not args.disable_rmm_pool)}")
    if args.protocol == "ucx":
        print(f"tcp            | {args.enable_tcp_over_ucx}")
        print(f"ib             | {args.enable_infiniband}")
        print(f"nvlink         | {args.enable_nvlink}")
    print(f"data-processed | {format_bytes(took_list[0][0])}")
    print("===============================")
    print("Wall-clock     | Throughput")
    print("-------------------------------")
    for idx, (data_processed, took) in enumerate(took_list):
        throughput = int(data_processed / took)
        m = format_time(took)
        m += " " * (15 - len(m))
        print(f"{m}| {format_bytes(throughput)}/s")
        # NOTE(review): only the leading numeric token of the formatted
        # throughput is kept, so the unit suffix is discarded; runs that
        # format with different units would not be comparable. Confirm this
        # is what plot_benchmark expects before changing it.
        t_runs[idx] = float(format_bytes(throughput).split(" ")[0])
    print("===============================")
    if args.markdown:
        print("\n```")

    if args.plot is not None:
        plot_benchmark(t_runs, args.plot, historical=True)

    if args.backend == "dask":
        if args.markdown:
            print(
                "<details>\n<summary>Worker-Worker Transfer Rates</summary>\n\n```"
            )
        print("(w1,w2)        | 25% 50% 75% (total nbytes)")
        print("-------------------------------")
        fmt = ("(%s,%s)        | %s %s %s (%s)"
               if named_workers else "(%02d,%02d)        | %s %s %s (%s)")
        for (d1, d2), bw in sorted(bandwidths.items()):
            print(fmt % (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))
        if args.markdown:
            print("```\n</details>\n")

    if args.benchmark_json:
        bandwidths_json = {}
        for (d1, d2), bw in sorted(bandwidths.items()):
            stats = zip(
                ["25%", "50%", "75%", "total_nbytes"],
                [bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]],
            )
            for i, v in stats:
                if named_workers:
                    # Must be an f-string: a plain literal would give every
                    # worker pair and statistic the same key, so only the
                    # last value would survive in the dict.
                    key = f"bandwidth_({d1},{d2})_{i}"
                else:
                    key = "(%02d,%02d)_%s" % (d1, d2, i)
                # Bandwidth strings carry a trailing "/s"; the total-bytes
                # string does not, and rstrip leaves it untouched.
                bandwidths_json[key] = parse_bytes(v.rstrip("/s"))

        # Append mode: one JSON object per line, one line per run.
        with open(args.benchmark_json, "a") as fp:
            for data_processed, took in took_list:
                fp.write(
                    dumps(
                        dict(
                            {
                                "backend": args.backend,
                                "partition_size": args.partition_size,
                                "in_parts": args.in_parts,
                                "protocol": args.protocol,
                                "devs": args.devs,
                                "device_memory_limit":
                                args.device_memory_limit,
                                "rmm_pool": not args.disable_rmm_pool,
                                "tcp": args.enable_tcp_over_ucx,
                                "ib": args.enable_infiniband,
                                "nvlink": args.enable_nvlink,
                                "data_processed": data_processed,
                                "wall_clock": took,
                                "throughput": data_processed / took,
                            },
                            **bandwidths_json,
                        )) + "\n")

    if args.multi_node:
        client.shutdown()
        client.close()
def main(args):
    """Run the merge benchmark and print timings and transfer rates.

    Connects to the scheduler at ``args.sched_addr`` when given, otherwise
    builds a cluster from ``get_cluster_options(args)``. The benchmark is
    executed ``args.runs`` times (at least once, with profiling only on the
    final run), after which a summary table, an optional plot and, for the
    "dask" backend, the worker-to-worker transfer rates are emitted.

    ``args.base_chunks`` and ``args.other_chunks`` are overwritten with the
    number of workers when they are falsy.
    """
    options = get_cluster_options(args)
    Cluster = options["class"]
    positional = options["args"]
    keywords = options["kwargs"]
    scheduler_addr = options["scheduler_addr"]

    if args.sched_addr:
        client = Client(args.sched_addr)
    else:
        filterwarnings("ignore",
                       message=".*NVLink.*rmm_pool_size.*",
                       category=UserWarning)

        cluster = Cluster(*positional, **keywords)
        if args.multi_node:
            import time

            # Give the workers a moment to come up and register with the
            # scheduler before connecting.
            # TODO: make this a command-line argument?
            time.sleep(15)
            client = Client(scheduler_addr)
        else:
            client = Client(cluster)

    if args.type == "gpu":
        client.run(
            setup_memory_pool,
            pool_size=args.rmm_pool_size,
            disable_pool=args.disable_rmm_pool,
            log_directory=args.rmm_log_directory,
        )
        # The scheduler gets its own RMM pool because it occasionally
        # deserializes CUDA objects. May cause issues with InfiniBand
        # otherwise.
        client.run_on_scheduler(
            setup_memory_pool,
            pool_size=1e9,
            disable_pool=args.disable_rmm_pool,
            log_directory=args.rmm_log_directory,
        )

    scheduler_workers = client.run_on_scheduler(get_scheduler_workers)
    n_workers = len(scheduler_workers)
    client.wait_for_workers(n_workers)

    # The "base" and "other" DataFrames may use different chunk counts;
    # both fall back to the number of workers.
    args.base_chunks = args.base_chunks or n_workers
    args.other_chunks = args.other_chunks or n_workers

    if args.all_to_all:
        all_to_all(client)

    # Every run except the last is unprofiled.
    results = [
        run(client, args, n_workers, write_profile=None)
        for _ in range(args.runs - 1)
    ]
    results.append(run(client, args, n_workers, write_profile=args.profile))

    # Gather the per-worker incoming transfer logs and group the samples
    # by (receiving worker, sending worker).
    transfer_logs = client.run(
        lambda dask_worker: dask_worker.incoming_transfer_log)
    raw_bandwidths = defaultdict(list)
    raw_nbytes = defaultdict(list)
    for receiver, entries in transfer_logs.items():
        for entry in entries:
            if entry["total"] >= args.ignore_size:
                raw_bandwidths[receiver, entry["who"]].append(
                    entry["bandwidth"])
                raw_nbytes[receiver, entry["who"]].append(entry["total"])

    # Re-key by worker name and format: bandwidth quartiles and byte totals.
    bandwidths = {}
    for (w1, w2), samples in raw_bandwidths.items():
        pair = (scheduler_workers[w1].name, scheduler_workers[w2].name)
        quartiles = numpy.quantile(samples, [0.25, 0.50, 0.75])
        bandwidths[pair] = ["%s/s" % format_bytes(q) for q in quartiles]
    total_nbytes = {}
    for (w1, w2), sizes in raw_nbytes.items():
        pair = (scheduler_workers[w1].name, scheduler_workers[w2].name)
        total_nbytes[pair] = format_bytes(sum(sizes))

    if args.shuffle_join:
        broadcast = False
    elif args.broadcast_join:
        broadcast = True
    else:
        broadcast = "default"

    t_runs = numpy.empty(len(results))
    if args.markdown:
        print("```")
    print("Merge benchmark")
    print("-------------------------------")
    print(f"backend        | {args.backend}")
    print(f"merge type     | {args.type}")
    print(f"rows-per-chunk | {args.chunk_size}")
    print(f"base-chunks    | {args.base_chunks}")
    print(f"other-chunks   | {args.other_chunks}")
    print(f"broadcast      | {broadcast}")
    print(f"protocol       | {args.protocol}")
    print(f"device(s)      | {args.devs}")
    print(f"rmm-pool       | {(not args.disable_rmm_pool)}")
    print(f"frac-match     | {args.frac_match}")
    if args.protocol == "ucx":
        print(f"tcp            | {args.enable_tcp_over_ucx}")
        print(f"ib             | {args.enable_infiniband}")
        print(f"nvlink         | {args.enable_nvlink}")
    print(f"data-processed | {format_bytes(results[0][0])}")
    print("===============================")
    print("Wall-clock     | Throughput")
    print("-------------------------------")
    for position, (nbytes, elapsed) in enumerate(results):
        rate = int(nbytes / elapsed)
        # Pad the wall-clock column to 15 characters.
        wall_clock = format_time(elapsed).ljust(15)
        print(f"{wall_clock}| {format_bytes(rate)}/s")
        # Only the leading numeric token of the formatted rate is stored.
        t_runs[position] = float(format_bytes(rate).split(" ")[0])
    print("===============================")
    if args.markdown:
        print("\n```")

    if args.plot is not None:
        plot_benchmark(t_runs, args.plot, historical=True)

    if args.backend == "dask":
        if args.markdown:
            print(
                "<details>\n<summary>Worker-Worker Transfer Rates</summary>\n\n```"
            )
        print("(w1,w2)     | 25% 50% 75% (total nbytes)")
        print("-------------------------------")
        if args.multi_node or args.sched_addr:
            row_fmt = "(%s,%s)     | %s %s %s (%s)"
        else:
            row_fmt = "(%02d,%02d)     | %s %s %s (%s)"
        for (d1, d2), quartiles in sorted(bandwidths.items()):
            low, mid, high = quartiles[0], quartiles[1], quartiles[2]
            print(row_fmt % (d1, d2, low, mid, high, total_nbytes[(d1, d2)]))
        if args.markdown:
            print("```\n</details>\n")

    if args.multi_node:
        client.shutdown()
        client.close()