def _test_peer_communication_tag(queue, rank, msg_size):
    """Pass a tagged message around a ring of peer processes.

    Rank 0 originates a random ``msg_size``-byte payload, sends it to its
    right neighbour and waits for the same bytes to arrive back from the
    left; every other rank simply relays whatever it receives.  Peer
    addresses are exchanged through ``queue``.
    """
    context = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(context)

    # Publish our own worker address, then learn both neighbours'.
    queue.put((rank, worker.get_address()))
    right_rank, right_address = queue.get()
    left_rank, left_address = queue.get()

    ep_right = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        right_address,
        endpoint_error_handling=True,
    )
    ep_left = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        left_address,
        endpoint_error_handling=True,
    )

    received = bytearray(msg_size)
    if rank == 0:
        # Originator: inject the payload and expect it back unchanged.
        payload = bytes(os.urandom(msg_size))
        blocking_send(worker, ep_right, payload, right_rank)
        blocking_recv(worker, ep_left, received, rank)
        assert payload == received
    else:
        # Relay: forward the received bytes to the next peer in the ring.
        blocking_recv(worker, ep_left, received, rank)
        blocking_send(worker, ep_right, received, right_rank)
def _server_probe(queue, transfer_api):
    """Server that probes and receives message after client disconnected.

    Note that since it is illegal to call progress() in callback functions,
    we keep a reference to the endpoint after the listener callback has
    terminated, this way we can progress even after Python blocking calls.

    Parameters
    ----------
    queue:
        Process-shared queue used to publish the listener port and to signal
        wireup completion back to the client process.
    transfer_api: str
        Either ``"am"`` (active messages) or anything else for tag transfers.
    """
    feature_flags = (
        ucx_api.Feature.AM if transfer_api == "am" else ucx_api.Feature.TAG,
    )
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)

    # Keep endpoint to be used from outside the listener callback
    ep = [None]

    def _listener_handler(conn_request):
        ep[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    # Fix: removed a stray trailing comma that made this statement build a
    # useless one-element tuple around the queue.put() return value.
    queue.put(listener.port)

    # Spin until the listener callback has created the endpoint.
    while ep[0] is None:
        worker.progress()
    ep = ep[0]

    # Ensure wireup and inform client before it can disconnect
    if transfer_api == "am":
        wireup = blocking_am_recv(worker, ep)
    else:
        wireup = bytearray(len(WireupMessage))
        blocking_recv(worker, ep, wireup)
    queue.put("wireup completed")

    # Ensure client has disconnected -- endpoint is not alive anymore
    while ep.is_alive() is True:
        worker.progress()

    # Probe/receive message even after the remote endpoint has disconnected
    if transfer_api == "am":
        while ep.am_probe() is False:
            worker.progress()
        received = blocking_am_recv(worker, ep)
    else:
        while worker.tag_probe(0) is False:
            worker.progress()
        received = bytearray(len(DataMessage))
        blocking_recv(worker, ep, received)

    assert wireup == WireupMessage
    assert received == DataMessage
def _echo_client(msg_size, port, endpoint_error_handling):
    """Connect to a local echo server, send random bytes, and verify that
    exactly the same bytes come back.

    Parameters
    ----------
    msg_size: int
        Number of bytes to send and expect back.
    port: int
        TCP port of the echo server on localhost.
    endpoint_error_handling: bool
        Forwarded to ``UCXEndpoint.create``.
    """
    context = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(context)
    endpoint = ucx_api.UCXEndpoint.create(
        worker,
        "localhost",
        port,
        endpoint_error_handling=endpoint_error_handling,
    )

    outgoing = bytes(os.urandom(msg_size))
    incoming = bytearray(msg_size)
    blocking_send(worker, endpoint, outgoing)
    blocking_recv(worker, endpoint, incoming)
    assert outgoing == incoming
def _test_peer_communication_rma(queue, rank, msg_size):
    """RMA ring test: each peer fills its own registered buffer with its
    rank byte, then remotely reads both neighbours' buffers and checks the
    contents.  A tag send/recv pass around the ring acts as a final barrier.
    """
    context = ucx_api.UCXContext(
        feature_flags=(ucx_api.Feature.RMA, ucx_api.Feature.TAG)
    )
    worker = ucx_api.UCXWorker(context)
    my_address = worker.get_address()

    # Register local memory and pack its rkey so peers can address it.
    mem_handle = context.alloc(msg_size)
    my_base = mem_handle.address
    my_rkey = mem_handle.pack_rkey()
    my_ep, my_mem = _rma_setup(worker, my_address, my_rkey, my_base, msg_size)

    # Write our rank into our own buffer; flush only if the put did not
    # complete immediately.
    payload = bytes(repeat(rank, msg_size))
    if not my_mem.put_nbi(payload):
        blocking_flush(my_ep)

    # Exchange (rank, address, rkey, base) with both ring neighbours.
    queue.put((rank, my_address, my_rkey, my_base))
    right_rank, right_address, right_rkey, right_base = queue.get()
    left_rank, left_address, left_rkey, left_base = queue.get()

    right_ep, right_mem = _rma_setup(
        worker, right_address, right_rkey, right_base, msg_size
    )
    right_buf = bytearray(msg_size)
    right_mem.get_nbi(right_buf)

    left_ep, left_mem = _rma_setup(
        worker, left_address, left_rkey, left_base, msg_size
    )
    left_buf = bytearray(msg_size)
    left_mem.get_nbi(left_buf)

    # Complete both outstanding remote gets before checking contents.
    blocking_flush(worker)
    assert left_buf == bytes(repeat(left_rank, msg_size))
    assert right_buf == bytes(repeat(right_rank, msg_size))

    # We use the blocking tag send/recv as a barrier implementation
    barrier_buf = bytearray(8)
    if rank == 0:
        token = bytes(os.urandom(8))
        blocking_send(worker, right_ep, token, right_rank)
        blocking_recv(worker, left_ep, barrier_buf, rank)
    else:
        blocking_recv(worker, left_ep, barrier_buf, rank)
        blocking_send(worker, right_ep, barrier_buf, right_rank)
def client(queue, port, server_address, args):
    """Benchmark client: run ``n_warmup_iter`` + ``n_iter`` round trips
    against the server and print a formatted report.

    Uses the AM or TAG transfer API depending on ``args.enable_am``, and
    numpy/cupy/cupy+rmm buffers depending on ``args.object_type``.
    Warmup iterations are excluded from the timing statistics.
    """
    if args.client_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.client_cpu_affinity])

    import numpy as np

    # Select the array module (and GPU memory pool) for message buffers.
    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.client_dev)
    else:
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.client_dev],
        )
        xp.cuda.runtime.setDevice(args.client_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(
        feature_flags=(
            ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
        )
    )
    worker = ucx_api.UCXWorker(ctx)
    register_am_allocators(args, worker)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        server_address,
        port,
        endpoint_error_handling=True,
    )

    send_msg = xp.arange(args.n_bytes, dtype="u1")
    if args.reuse_alloc:
        # One receive buffer reused for every iteration.
        recv_msg = xp.zeros(args.n_bytes, dtype="u1")

    # Wireup exchange so that timing below excludes connection setup.
    if args.enable_am:
        blocking_am_send(worker, ep, send_msg)
        blocking_am_recv(worker, ep)
    else:
        wireup_recv = bytearray(len(WireupMessage))
        blocking_send(worker, ep, WireupMessage)
        blocking_recv(worker, ep, wireup_recv)

    # Shared counters for the delayed-progress (non-blocking) mode.
    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def maybe_progress():
        # Throttle: progress the worker until outstanding ops drop below cap.
        while outstanding[0] >= args.max_outstanding:
            worker.progress()

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    if args.cuda_profile:
        xp.cuda.profiler.start()

    times = []
    for i in range(args.n_iter + args.n_warmup_iter):
        start = clock()
        if args.enable_am:
            blocking_am_send(worker, ep, send_msg)
            blocking_am_recv(worker, ep)
        else:
            if not args.reuse_alloc:
                recv_msg = xp.zeros(args.n_bytes, dtype="u1")

            if args.delay_progress:
                maybe_progress()
                non_blocking_send(worker, ep, send_msg, op_started, op_completed)
                maybe_progress()
                non_blocking_recv(worker, ep, recv_msg, op_started, op_completed)
            else:
                blocking_send(worker, ep, send_msg)
                blocking_recv(worker, ep, recv_msg)
        stop = clock()
        # Discard warmup iterations from the timing statistics.
        if i >= args.n_warmup_iter:
            times.append(stop - start)

    if args.delay_progress:
        # Drain: two ops (send + recv) per iteration must have completed.
        while finished[0] != 2 * (args.n_iter + args.n_warmup_iter):
            worker.progress()

    if args.cuda_profile:
        xp.cuda.profiler.stop()

    assert len(times) == args.n_iter
    # Factor 2: each timed round trip moves n_bytes in both directions.
    bw_avg = format_bytes(2 * args.n_iter * args.n_bytes / sum(times))
    bw_med = format_bytes(2 * args.n_bytes / np.median(times))
    lat_avg = int(sum(times) * 1e9 / (2 * args.n_iter))
    lat_med = int(np.median(times) * 1e9 / 2)

    delay_progress_str = (
        f"True ({args.max_outstanding})" if args.delay_progress is True else "False"
    )

    print("Roundtrip benchmark")
    print_separator(separator="=")
    print_key_value(key="Iterations", value=f"{args.n_iter}")
    print_key_value(key="Bytes", value=f"{format_bytes(args.n_bytes)}")
    print_key_value(key="Object type", value=f"{args.object_type}")
    print_key_value(key="Reuse allocation", value=f"{args.reuse_alloc}")
    print_key_value(key="Transfer API", value=f"{'AM' if args.enable_am else 'TAG'}")
    print_key_value(key="Delay progress", value=f"{delay_progress_str}")
    print_key_value(key="UCX_TLS", value=f"{ucp.get_config()['TLS']}")
    print_key_value(key="UCX_NET_DEVICES", value=f"{ucp.get_config()['NET_DEVICES']}")
    print_separator(separator="=")
    if args.object_type == "numpy":
        print_key_value(key="Device(s)", value="CPU-only")
        s_aff = (
            args.server_cpu_affinity
            if args.server_cpu_affinity >= 0
            else "affinity not set"
        )
        c_aff = (
            args.client_cpu_affinity
            if args.client_cpu_affinity >= 0
            else "affinity not set"
        )
        print_key_value(key="Server CPU", value=f"{s_aff}")
        print_key_value(key="Client CPU", value=f"{c_aff}")
    else:
        print_key_value(key="Device(s)", value=f"{args.server_dev}, {args.client_dev}")
    print_separator(separator="=")
    print_key_value("Bandwidth (average)", value=f"{bw_avg}/s")
    print_key_value("Bandwidth (median)", value=f"{bw_med}/s")
    print_key_value("Latency (average)", value=f"{lat_avg} ns")
    print_key_value("Latency (median)", value=f"{lat_med} ns")
    if not args.no_detailed_report:
        print_separator(separator="=")
        print_key_value(key="Iterations", value="Bandwidth, Latency")
        print_separator(separator="-")
        for i, t in enumerate(times):
            ts = format_bytes(2 * args.n_bytes / t)
            lat = int(t * 1e9 / 2)
            print_key_value(key=i, value=f"{ts}/s, {lat}ns")
def client(queue, port, server_address, args):
    """Benchmark client: run ``n_iter`` round trips against the server and
    print a plain-text report.

    Uses the AM or TAG transfer API depending on ``args.enable_am``, and
    numpy/cupy/cupy+rmm buffers depending on ``args.object_type``.
    Endpoint error handling is enabled only on UCX >= 1.10.
    """
    if args.client_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.client_cpu_affinity])

    import numpy as np

    # Select the array module (and GPU memory pool) for message buffers.
    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.client_dev)
    else:
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.client_dev],
        )
        xp.cuda.runtime.setDevice(args.client_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(feature_flags=(
        ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
    ))
    worker = ucx_api.UCXWorker(ctx)
    register_am_allocators(args, worker)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        server_address,
        port,
        endpoint_error_handling=ucx_api.get_ucx_version() >= (1, 10, 0),
    )
    send_msg = xp.arange(args.n_bytes, dtype="u1")
    if args.reuse_alloc:
        # One receive buffer reused for every iteration.
        recv_msg = xp.zeros(args.n_bytes, dtype="u1")

    # Wireup exchange so that timing below excludes connection setup.
    if args.enable_am:
        blocking_am_send(worker, ep, send_msg)
        blocking_am_recv(worker, ep)
    else:
        wireup_recv = bytearray(len(WireupMessage))
        blocking_send(worker, ep, WireupMessage)
        blocking_recv(worker, ep, wireup_recv)

    # Shared counters for the delayed-progress (non-blocking) mode.
    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def maybe_progress():
        # Throttle: progress the worker until outstanding ops drop below cap.
        while outstanding[0] >= args.max_outstanding:
            worker.progress()

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    if args.cuda_profile:
        xp.cuda.profiler.start()

    times = []
    for i in range(args.n_iter):
        start = clock()
        if args.enable_am:
            blocking_am_send(worker, ep, send_msg)
            blocking_am_recv(worker, ep)
        else:
            if not args.reuse_alloc:
                recv_msg = xp.zeros(args.n_bytes, dtype="u1")

            if args.delay_progress:
                maybe_progress()
                non_blocking_send(worker, ep, send_msg, op_started, op_completed)
                maybe_progress()
                non_blocking_recv(worker, ep, recv_msg, op_started, op_completed)
            else:
                blocking_send(worker, ep, send_msg)
                blocking_recv(worker, ep, recv_msg)
        stop = clock()
        times.append(stop - start)

    if args.delay_progress:
        # Drain: two ops (send + recv) per iteration must have completed.
        while finished[0] != 2 * args.n_iter:
            worker.progress()

    if args.cuda_profile:
        xp.cuda.profiler.stop()

    assert len(times) == args.n_iter

    delay_progress_str = (f"True ({args.max_outstanding})"
                          if args.delay_progress is True else "False")

    print("Roundtrip benchmark")
    print("--------------------------")
    print(f"n_iter | {args.n_iter}")
    print(f"n_bytes | {format_bytes(args.n_bytes)}")
    print(f"object | {args.object_type}")
    print(f"reuse alloc | {args.reuse_alloc}")
    print(f"transfer API | {'AM' if args.enable_am else 'TAG'}")
    print(f"delay progress | {delay_progress_str}")
    print(f"UCX_TLS | {ucp.get_config()['TLS']}")
    print(f"UCX_NET_DEVICES | {ucp.get_config()['NET_DEVICES']}")
    print("==========================")
    if args.object_type == "numpy":
        print("Device(s) | CPU-only")
        s_aff = (args.server_cpu_affinity
                 if args.server_cpu_affinity >= 0 else "affinity not set")
        c_aff = (args.client_cpu_affinity
                 if args.client_cpu_affinity >= 0 else "affinity not set")
        print(f"Server CPU | {s_aff}")
        print(f"Client CPU | {c_aff}")
    else:
        print(f"Device(s) | {args.server_dev}, {args.client_dev}")
    # Factor 2: each timed round trip moves n_bytes in both directions.
    avg = format_bytes(2 * args.n_iter * args.n_bytes / sum(times))
    med = format_bytes(2 * args.n_bytes / np.median(times))
    print(f"Average | {avg}/s")
    print(f"Median | {med}/s")
    if not args.no_detailed_report:
        print("--------------------------")
        print("Iterations")
        print("--------------------------")
        for i, t in enumerate(times):
            ts = format_bytes(2 * args.n_bytes / t)
            # Right-align the bandwidth column to 9 characters.
            ts = (" " * (9 - len(ts))) + ts
            print("%03d |%s/s" % (i, ts))