def _server_cancel(queue, transfer_api):
    """Server that establishes an endpoint to client and immediately closes
    it, triggering received messages to be canceled on the client.
    """
    if transfer_api == "am":
        feature_flags = (ucx_api.Feature.AM,)
    else:
        feature_flags = (ucx_api.Feature.TAG,)
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)

    # The endpoint created inside the listener callback is stashed in a
    # one-element list so it can be used from outside the callback.
    endpoint_container = [None]

    def _listener_handler(conn_request):
        endpoint_container[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    queue.put(listener.port)

    # Progress until a client connects and the callback has fired.
    while endpoint_container[0] is None:
        worker.progress()

    # Close right away; the client's pending receives get canceled.
    endpoint_container[0].close()
    worker.progress()
def test_feature_flags_mismatch(feature_flag):
    """Transfer functions of features absent from the context must raise."""
    ctx = ucx_api.UCXContext(feature_flags=(feature_flag,))
    worker = ucx_api.UCXWorker(ctx)
    addr = worker.get_address()
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, addr, endpoint_error_handling=False
    )
    msg = Array(bytearray(10))

    # (feature, expected error text, transfer calls guarded by that feature)
    checks = [
        (
            ucx_api.Feature.TAG,
            "UCXContext must be created with `Feature.TAG`",
            [
                lambda: ucx_api.tag_send_nb(ep, msg, msg.nbytes, 0, None),
                lambda: ucx_api.tag_recv_nb(worker, msg, msg.nbytes, 0, None),
            ],
        ),
        (
            ucx_api.Feature.STREAM,
            "UCXContext must be created with `Feature.STREAM`",
            [
                lambda: ucx_api.stream_send_nb(ep, msg, msg.nbytes, None),
                lambda: ucx_api.stream_recv_nb(ep, msg, msg.nbytes, None),
            ],
        ),
        (
            ucx_api.Feature.AM,
            "UCXContext must be created with `Feature.AM`",
            [
                lambda: ucx_api.am_send_nbx(ep, msg, msg.nbytes, None),
                lambda: ucx_api.am_recv_nb(ep, None),
            ],
        ),
    ]
    for feature, err_match, calls in checks:
        if feature_flag == feature:
            continue
        for call in calls:
            with pytest.raises(ValueError, match=err_match):
                call()
def test_force_requests():
    """Issue puts until the transport is exhausted and a request object is
    returned, then flush and verify puts complete immediately again.

    NOTE(review): `_` is passed as the completion callback to `put_nb`;
    presumably a module-level no-op handler defined elsewhere in this
    file -- confirm against the full file.
    """
    msg_size = 1024
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    # Self-connected endpoint: the remote memory is this process's own buffer.
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)
    self_mem = ucx_api.RemoteMemory(rkey, mem.address, msg_size)

    counter = 0
    send_msg = bytes(os.urandom(msg_size))
    # Loop until put_nb stops completing inline (returns non-None).
    req = self_mem.put_nb(send_msg, _)
    while req is None:
        counter = counter + 1
        req = self_mem.put_nb(send_msg, _)
        # This `if` is here because some combinations of transports, such as
        # normal desktop PCs, will never have their transports exhausted. So
        # we have a break to make sure this test still completes
        if counter > 10000:
            pytest.xfail("Could not generate a request")
    blocking_flush(worker)
    # Drain any remaining progress work.
    while worker.progress():
        pass
    # After the flush, puts are expected to complete inline (return None).
    while self_mem.put_nb(send_msg, _):
        pass
    blocking_flush(worker)
def _test_peer_communication_tag(queue, rank, msg_size):
    """Ring exchange over TAG: rank 0 originates a random message and checks
    it comes back; other ranks receive from the left and forward right."""
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)
    queue.put((rank, worker.get_address()))
    right_rank, right_address = queue.get()
    left_rank, left_address = queue.get()

    def _connect(address):
        return ucx_api.UCXEndpoint.create_from_worker_address(
            worker,
            address,
            endpoint_error_handling=True,
        )

    right_ep = _connect(right_address)
    left_ep = _connect(left_address)

    recv_msg = bytearray(msg_size)
    if rank == 0:
        payload = bytes(os.urandom(msg_size))
        blocking_send(worker, right_ep, payload, right_rank)
        blocking_recv(worker, left_ep, recv_msg, rank)
        assert payload == recv_msg
    else:
        blocking_recv(worker, left_ep, recv_msg, rank)
        blocking_send(worker, right_ep, recv_msg, right_rank)
def test_ucxio_seek_bad(seek_loc, seek_flag):
    """Out-of-range seeks: a positive seek location is expected to read empty."""
    msg_size = 1024
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)
    uio = ucx_api.UCXIO(mem.address, msg_size, rkey)

    payload = bytes(os.urandom(msg_size))
    uio.write(payload)
    uio.seek(seek_loc, seek_flag)
    result = uio.read(len(payload))

    expected = b"" if seek_loc > 0 else payload
    assert result == expected
def _server_probe(queue, transfer_api):
    """Server that probes and receives message after client disconnected.

    Note that since it is illegal to call progress() in callback functions,
    we keep a reference to the endpoint after the listener callback has
    terminated, this way we can progress even after Python blocking calls.
    """
    feature_flags = (ucx_api.Feature.AM if transfer_api == "am" else ucx_api.Feature.TAG,)
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)

    # Keep endpoint to be used from outside the listener callback
    ep = [None]

    def _listener_handler(conn_request):
        ep[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    # Fixed: a stray trailing comma previously turned this statement into a
    # discarded one-element tuple around the `put` call.
    queue.put(listener.port)

    while ep[0] is None:
        worker.progress()
    # From here on work with the endpoint itself, not the container.
    ep = ep[0]

    # Ensure wireup and inform client before it can disconnect
    if transfer_api == "am":
        wireup = blocking_am_recv(worker, ep)
    else:
        wireup = bytearray(len(WireupMessage))
        blocking_recv(worker, ep, wireup)
    queue.put("wireup completed")

    # Ensure client has disconnected -- endpoint is not alive anymore
    while ep.is_alive() is True:
        worker.progress()

    # Probe/receive message even after the remote endpoint has disconnected
    if transfer_api == "am":
        while ep.am_probe() is False:
            worker.progress()
        received = blocking_am_recv(worker, ep)
    else:
        while worker.tag_probe(0) is False:
            worker.progress()
        received = bytearray(len(DataMessage))
        blocking_recv(worker, ep, received)

    assert wireup == WireupMessage
    assert received == DataMessage
def test_pickle_ucx_address():
    """A worker address survives a pickle round-trip: equal hash and bytes."""
    ctx = ucx_api.UCXContext()
    worker = ucx_api.UCXWorker(ctx)
    original = worker.get_address()
    pickled = pickle.dumps(original)
    original_hash = hash(original)
    original = bytes(original)
    restored = pickle.loads(pickled)
    assert original_hash == hash(restored)
    assert bytes(original) == bytes(restored)
def _echo_server(get_queue, put_queue, msg_size, datatype):
    """Server that send received message back to the client

    Notice, since it is illegal to call progress() in call-back functions,
    we use a "chain" of call-back functions.
    """
    data = get_data()[datatype]

    ctx = ucx_api.UCXContext(
        config_dict={"RNDV_THRESH": str(RNDV_THRESH)},
        feature_flags=(ucx_api.Feature.AM,),
    )
    worker = ucx_api.UCXWorker(ctx)
    worker.register_am_allocator(data["allocator"], data["memory_type"])

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    ep = None

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None

    def _recv_handle(recv_obj, exception, ep):
        assert exception is None
        msg = Array(recv_obj)
        # Echo the received message straight back.
        ucx_api.am_send_nbx(ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,))

    def _listener_handler(conn_request):
        # Fixed: was `global ep`, which left the enclosing function's
        # `ep = None` untouched and stored the endpoint in a module-level
        # variable instead. `nonlocal` rebinds the enclosing `ep`, matching
        # the stated intent of keeping a reference in this function's scope.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        # Wireup
        ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))
        # Data
        ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    # Progress until the client signals completion via the queue.
    while True:
        worker.progress()
        try:
            get_queue.get(block=False, timeout=0.1)
        except QueueIsEmpty:
            continue
        else:
            break
def _echo_client(msg_size, port, endpoint_error_handling):
    """Connect to the local echo server, send random bytes, verify the echo."""
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        "localhost",
        port,
        endpoint_error_handling=endpoint_error_handling,
    )
    payload = bytes(os.urandom(msg_size))
    echoed = bytearray(msg_size)
    blocking_send(worker, ep, payload)
    blocking_recv(worker, ep, echoed)
    assert payload == echoed
def test_rkey_unpack():
    """Unpacking a packed rkey on a self-connected endpoint yields a handle."""
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, 1024)
    packed = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=get_endpoint_error_handling_default(),
    )
    assert ep.unpack_rkey(packed) is not None
def test_listener_ip_port():
    """A freshly created listener exposes a non-empty IP and a valid port."""
    ctx = ucx_api.UCXContext()
    worker = ucx_api.UCXWorker(ctx)

    listener = ucx_api.UCXListener(
        worker=worker, port=0, cb_func=lambda conn_request: None
    )

    assert isinstance(listener.ip, str)
    assert listener.ip
    assert isinstance(listener.port, int)
    assert 0 <= listener.port <= 65535
def test_flush():
    """Flushing an endpoint either completes inline (returns None) or yields
    a request whose status goes from "pending" to "finished".

    NOTE(review): `_` is used as the flush callback; presumably a no-op
    handler defined elsewhere in this file -- confirm.
    """
    ctx = ucx_api.UCXContext({})
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    req = ep.flush(_)
    # Fixed: the condition was inverted (`if req is None`), which would
    # dereference `req.info` on None and skip all verification whenever a
    # real request object was returned.
    if req is not None:
        info = req.info
        while info["status"] == "pending":
            worker.progress()
        assert info["status"] == "finished"
def _echo_server(get_queue, put_queue, msg_size):
    """Server that send received message back to the client

    Notice, since it is illegal to call progress() in call-back functions,
    we use a "chain" of call-back functions.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    ep = None

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None

    def _recv_handle(request, exception, ep, msg):
        assert exception is None
        # Echo the received message straight back to the sender.
        ucx_api.tag_send_nb(
            ep, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
        )

    def _listener_handler(conn_request):
        # Fixed: was `global ep`, which left the enclosing function's
        # `ep = None` untouched and stored the endpoint in a module-level
        # variable instead. `nonlocal` rebinds the enclosing `ep`, matching
        # the stated intent of keeping a reference in this function's scope.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        msg = Array(bytearray(msg_size))
        ucx_api.tag_recv_nb(
            worker, msg, msg.nbytes, tag=0, cb_func=_recv_handle, cb_args=(ep, msg)
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    # Progress until the client signals completion via the queue.
    while True:
        worker.progress()
        try:
            get_queue.get(block=False, timeout=0.1)
        except QueueIsEmpty:
            continue
        else:
            break
def _client_cancel(queue, transfer_api):
    """Client that connects to server and waits for messages to be received,
    because the server closes without sending anything, the messages will
    trigger cancelation.
    """
    if transfer_api == "am":
        feature_flags = (ucx_api.Feature.AM,)
    else:
        feature_flags = (ucx_api.Feature.TAG,)
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)
    port = queue.get()
    ep = ucx_api.UCXEndpoint.create(
        worker,
        get_address(),
        port,
        endpoint_error_handling=True,
    )

    result = [None]
    if transfer_api == "am":
        ucx_api.am_recv_nb(ep, cb_func=_handler, cb_args=(result,))
        match_msg = ".*am_recv.*"
    else:
        msg = Array(bytearray(1))
        ucx_api.tag_recv_nb(
            worker,
            msg,
            msg.nbytes,
            tag=0,
            cb_func=_handler,
            cb_args=(result,),
            ep=ep,
        )
        match_msg = ".*tag_recv_nb.*"

    # Wait for the server to close the endpoint, cancel the pending receive,
    # then progress until the callback has recorded the cancelation.
    while ep.is_alive():
        worker.progress()
    canceled = worker.cancel_inflight_messages()
    while result[0] is None:
        worker.progress()

    assert canceled == 1
    assert isinstance(result[0], UCXCanceled)
    assert re.match(match_msg, result[0].args[0])
def _client(port, server_close_callback):
    """Connect to the server; either close our end immediately, or install a
    close callback and wait for the server to close its end."""
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        ucx_api.get_address(),
        port,
        endpoint_error_handling=True,
    )
    if server_close_callback is True:
        ep.close()
        worker.progress()
        return
    closed = [False]
    ep.set_close_callback(functools.partial(_close_callback, closed))
    while not closed[0]:
        worker.progress()
def test_ucxio(msg_size):
    """Round-trip a random payload through UCXIO over self-mapped memory."""
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    uio = ucx_api.UCXIO(mem.address, msg_size, ep.unpack_rkey(packed_rkey))

    payload = bytes(os.urandom(msg_size))
    uio.write(payload)
    uio.seek(0)
    assert uio.read(msg_size) == payload
def _test_peer_communication_rma(queue, rank, msg_size):
    """RMA peer exchange: each rank fills its own buffer with its rank value
    via a self-endpoint put, then reads both neighbors' buffers with get_nbi
    and verifies the contents, finishing with a tag send/recv barrier.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.RMA, ucx_api.Feature.TAG))
    worker = ucx_api.UCXWorker(ctx)
    self_address = worker.get_address()
    mem_handle = ctx.alloc(msg_size)
    self_base = mem_handle.address
    self_prkey = mem_handle.pack_rkey()
    self_ep, self_mem = _rma_setup(worker, self_address, self_prkey, self_base, msg_size)
    # Fill our own buffer with our rank value.
    send_msg = bytes(repeat(rank, msg_size))
    if not self_mem.put_nbi(send_msg):
        # NOTE(review): flushes only when put_nbi returns falsy -- presumably
        # that return value signals the operation needs a flush to complete;
        # confirm against RemoteMemory.put_nbi semantics.
        blocking_flush(self_ep)

    # Exchange (rank, address, packed rkey, base address) with the two peers.
    queue.put((rank, self_address, self_prkey, self_base))
    right_rank, right_address, right_prkey, right_base = queue.get()
    left_rank, left_address, left_prkey, left_base = queue.get()

    # Remote reads of both neighbors' buffers.
    right_ep, right_mem = _rma_setup(
        worker, right_address, right_prkey, right_base, msg_size
    )
    right_msg = bytearray(msg_size)
    right_mem.get_nbi(right_msg)

    left_ep, left_mem = _rma_setup(
        worker, left_address, left_prkey, left_base, msg_size
    )
    left_msg = bytearray(msg_size)
    left_mem.get_nbi(left_msg)

    # Flush to complete both outstanding gets before checking contents.
    blocking_flush(worker)
    assert left_msg == bytes(repeat(left_rank, msg_size))
    assert right_msg == bytes(repeat(right_rank, msg_size))

    # We use the blocking tag send/recv as a barrier implementation
    recv_msg = bytearray(8)
    if rank == 0:
        send_msg = bytes(os.urandom(8))
        blocking_send(worker, right_ep, send_msg, right_rank)
        blocking_recv(worker, left_ep, recv_msg, rank)
    else:
        blocking_recv(worker, left_ep, recv_msg, rank)
        blocking_send(worker, right_ep, recv_msg, right_rank)
def test_implicit(msg_size):
    """Implicit (nbi) put/get round-trip through self-mapped remote memory."""
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)
    self_mem = ucx_api.RemoteMemory(rkey, mem.address, msg_size)

    payload = bytes(os.urandom(msg_size))
    if not self_mem.put_nbi(payload):
        blocking_flush(ep)
    readback = bytearray(len(payload))
    if not self_mem.get_nbi(readback):
        blocking_flush(ep)
    assert payload == readback
def _client_probe(queue, transfer_api):
    """Send wireup and data messages, then wait for the server to confirm
    wireup before disconnecting."""
    if transfer_api == "am":
        feature_flags = (ucx_api.Feature.AM,)
        send = blocking_am_send
    else:
        feature_flags = (ucx_api.Feature.TAG,)
        send = blocking_send
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)
    port = queue.get()
    ep = ucx_api.UCXEndpoint.create(
        worker,
        get_address(),
        port,
        endpoint_error_handling=True,
    )

    send(worker, ep, WireupMessage)
    send(worker, ep, DataMessage)

    # Wait for wireup before disconnecting
    assert queue.get() == "wireup completed"
def test_ucxio_seek_good(seek_data):
    """After a valid seek, reading 4 bytes matches the written payload slice."""
    seek_flag, seek_dest = seek_data
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, 1024)
    msg_size = mem.length
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    uio = ucx_api.UCXIO(mem.address, msg_size, ep.unpack_rkey(packed_rkey))

    payload = bytes(os.urandom(msg_size))
    uio.write(payload)
    uio.seek(seek_dest, seek_flag)
    assert uio.read(4) == payload[seek_dest:seek_dest + 4]
def _echo_client(msg_size, datatype, port, endpoint_error_handling):
    """AM echo client: sends a wireup message plus generated data and
    validates what the server echoes back."""
    data = get_data()[datatype]

    ctx = ucx_api.UCXContext(
        config_dict={"RNDV_THRESH": str(RNDV_THRESH)},
        feature_flags=(ucx_api.Feature.AM,),
    )
    worker = ucx_api.UCXWorker(ctx)
    worker.register_am_allocator(data["allocator"], data["memory_type"])
    ep = ucx_api.UCXEndpoint.create(
        worker,
        "localhost",
        port,
        endpoint_error_handling=endpoint_error_handling,
    )

    # The wireup message is sent to ensure endpoints are connected, otherwise
    # UCX may not perform any rendezvous transfers.
    send_wireup = bytearray(b"wireup")
    send_data = data["generator"](msg_size)

    blocking_am_send(worker, ep, send_wireup)
    blocking_am_send(worker, ep, send_data)

    recv_wireup = blocking_am_recv(worker, ep)
    recv_data = blocking_am_recv(worker, ep)

    # Cast recv_wireup to bytearray when using NumPy as a host allocator,
    # this ensures the assertion below is correct
    if datatype == "numpy":
        recv_wireup = bytearray(recv_wireup)
    assert bytearray(recv_wireup) == send_wireup
    if data["memory_type"] == "cuda" and send_data.nbytes < RNDV_THRESH:
        # Eager messages are always received on the host, if no host
        # allocator is registered UCX-Py defaults to `bytearray`.
        assert recv_data == bytearray(send_data.get())
    # NOTE(review): the validator runs unconditionally, including after the
    # CUDA-eager branch above where `recv_data` is host-side -- confirm this
    # is intended rather than a missing `else`.
    data["validator"](recv_data, send_data)
def _server(queue, server_close_callback):
    """Server that send received message back to the client

    Notice, since it is illegal to call progress() in call-back functions,
    we use a "chain" of call-back functions.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG, ))
    worker = ucx_api.UCXWorker(ctx)
    listener_finished = [False]
    closed = [False]

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    # ep = None

    def _listener_handler(conn_request):
        # NOTE(review): `global ep` stores the endpoint at module level (there
        # is no enclosing local; the `ep = None` initializer above is commented
        # out). That still keeps the endpoint alive, but `nonlocal` with the
        # local re-enabled would match the comment's stated intent -- confirm.
        global ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        if server_close_callback is True:
            ep.set_close_callback(functools.partial(_close_callback, closed))
        listener_finished[0] = True

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    queue.put(listener.port)
    if server_close_callback is True:
        # Wait for the client to close; our close callback flips `closed`.
        while closed[0] is False:
            worker.progress()
        assert closed[0] is True
    else:
        # Just wait until a connection has been accepted.
        while listener_finished[0] is False:
            worker.progress()
def client(queue, port, server_address, args):
    """Benchmark client: runs `args.n_iter` round-trips against the server
    and prints a bandwidth report.

    `args.object_type` selects the allocator: "numpy" (host), "cupy", or
    cupy backed by an RMM pool.
    """
    if args.client_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.client_cpu_affinity])

    import numpy as np

    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.client_dev)
    else:
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.client_dev],
        )
        xp.cuda.runtime.setDevice(args.client_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(feature_flags=(
        ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
    ))
    worker = ucx_api.UCXWorker(ctx)
    register_am_allocators(args, worker)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        server_address,
        port,
        # Endpoint error handling is only enabled on UCX >= 1.10.
        endpoint_error_handling=ucx_api.get_ucx_version() >= (1, 10, 0),
    )

    send_msg = xp.arange(args.n_bytes, dtype="u1")
    if args.reuse_alloc:
        recv_msg = xp.zeros(args.n_bytes, dtype="u1")

    # Wireup exchange before timed iterations start.
    if args.enable_am:
        blocking_am_send(worker, ep, send_msg)
        blocking_am_recv(worker, ep)
    else:
        wireup_recv = bytearray(len(WireupMessage))
        blocking_send(worker, ep, WireupMessage)
        blocking_recv(worker, ep, wireup_recv)

    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def maybe_progress():
        # Throttle: progress until in-flight operations drop below the cap.
        while outstanding[0] >= args.max_outstanding:
            worker.progress()

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    if args.cuda_profile:
        xp.cuda.profiler.start()

    times = []
    for i in range(args.n_iter):
        start = clock()
        if args.enable_am:
            blocking_am_send(worker, ep, send_msg)
            blocking_am_recv(worker, ep)
        else:
            if not args.reuse_alloc:
                recv_msg = xp.zeros(args.n_bytes, dtype="u1")

            if args.delay_progress:
                maybe_progress()
                non_blocking_send(worker, ep, send_msg, op_started, op_completed)
                maybe_progress()
                non_blocking_recv(worker, ep, recv_msg, op_started, op_completed)
            else:
                blocking_send(worker, ep, send_msg)
                blocking_recv(worker, ep, recv_msg)
        stop = clock()
        times.append(stop - start)

    if args.delay_progress:
        # Each iteration issues one send and one recv: 2 * n_iter completions.
        while finished[0] != 2 * args.n_iter:
            worker.progress()

    if args.cuda_profile:
        xp.cuda.profiler.stop()

    assert len(times) == args.n_iter
    delay_progress_str = (
        f"True ({args.max_outstanding})"
        if args.delay_progress is True
        else "False"
    )
    print("Roundtrip benchmark")
    print("--------------------------")
    print(f"n_iter | {args.n_iter}")
    print(f"n_bytes | {format_bytes(args.n_bytes)}")
    print(f"object | {args.object_type}")
    print(f"reuse alloc | {args.reuse_alloc}")
    print(f"transfer API | {'AM' if args.enable_am else 'TAG'}")
    print(f"delay progress | {delay_progress_str}")
    print(f"UCX_TLS | {ucp.get_config()['TLS']}")
    print(f"UCX_NET_DEVICES | {ucp.get_config()['NET_DEVICES']}")
    print("==========================")
    if args.object_type == "numpy":
        print("Device(s) | CPU-only")
        s_aff = (
            args.server_cpu_affinity
            if args.server_cpu_affinity >= 0
            else "affinity not set"
        )
        c_aff = (
            args.client_cpu_affinity
            if args.client_cpu_affinity >= 0
            else "affinity not set"
        )
        print(f"Server CPU | {s_aff}")
        print(f"Client CPU | {c_aff}")
    else:
        print(f"Device(s) | {args.server_dev}, {args.client_dev}")
    # Bandwidth: 2 transfers (send + recv) of n_bytes per iteration.
    avg = format_bytes(2 * args.n_iter * args.n_bytes / sum(times))
    med = format_bytes(2 * args.n_bytes / np.median(times))
    print(f"Average | {avg}/s")
    print(f"Median | {med}/s")
    if not args.no_detailed_report:
        print("--------------------------")
        print("Iterations")
        print("--------------------------")
        for i, t in enumerate(times):
            ts = format_bytes(2 * args.n_bytes / t)
            # Right-align the throughput string to a 9-character column.
            ts = (" " * (9 - len(ts))) + ts
            print("%03d |%s/s" % (i, ts))
def client(queue, port, server_address, args):
    """Benchmark client with warmup: the first `args.n_warmup_iter` round-trip
    timings are discarded; the remaining `args.n_iter` are reported as
    bandwidth and latency via `print_key_value`/`print_separator`.
    """
    if args.client_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.client_cpu_affinity])

    import numpy as np

    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.client_dev)
    else:
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.client_dev],
        )
        xp.cuda.runtime.setDevice(args.client_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(
        feature_flags=(
            ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
        )
    )
    worker = ucx_api.UCXWorker(ctx)
    register_am_allocators(args, worker)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        server_address,
        port,
        endpoint_error_handling=True,
    )

    send_msg = xp.arange(args.n_bytes, dtype="u1")
    if args.reuse_alloc:
        recv_msg = xp.zeros(args.n_bytes, dtype="u1")

    # Wireup exchange before timed iterations start.
    if args.enable_am:
        blocking_am_send(worker, ep, send_msg)
        blocking_am_recv(worker, ep)
    else:
        wireup_recv = bytearray(len(WireupMessage))
        blocking_send(worker, ep, WireupMessage)
        blocking_recv(worker, ep, wireup_recv)

    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def maybe_progress():
        # Throttle: progress until in-flight operations drop below the cap.
        while outstanding[0] >= args.max_outstanding:
            worker.progress()

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    if args.cuda_profile:
        xp.cuda.profiler.start()

    times = []
    for i in range(args.n_iter + args.n_warmup_iter):
        start = clock()
        if args.enable_am:
            blocking_am_send(worker, ep, send_msg)
            blocking_am_recv(worker, ep)
        else:
            if not args.reuse_alloc:
                recv_msg = xp.zeros(args.n_bytes, dtype="u1")

            if args.delay_progress:
                maybe_progress()
                non_blocking_send(worker, ep, send_msg, op_started, op_completed)
                maybe_progress()
                non_blocking_recv(worker, ep, recv_msg, op_started, op_completed)
            else:
                blocking_send(worker, ep, send_msg)
                blocking_recv(worker, ep, recv_msg)
        stop = clock()
        # Only record timings for post-warmup iterations.
        if i >= args.n_warmup_iter:
            times.append(stop - start)

    if args.delay_progress:
        # Two completions (send + recv) per iteration, warmup included.
        while finished[0] != 2 * (args.n_iter + args.n_warmup_iter):
            worker.progress()

    if args.cuda_profile:
        xp.cuda.profiler.stop()

    assert len(times) == args.n_iter
    # Bandwidth counts 2 transfers of n_bytes per iteration; latency is the
    # one-way time in nanoseconds (half a round-trip).
    bw_avg = format_bytes(2 * args.n_iter * args.n_bytes / sum(times))
    bw_med = format_bytes(2 * args.n_bytes / np.median(times))
    lat_avg = int(sum(times) * 1e9 / (2 * args.n_iter))
    lat_med = int(np.median(times) * 1e9 / 2)

    delay_progress_str = (
        f"True ({args.max_outstanding})" if args.delay_progress is True else "False"
    )

    print("Roundtrip benchmark")
    print_separator(separator="=")
    print_key_value(key="Iterations", value=f"{args.n_iter}")
    print_key_value(key="Bytes", value=f"{format_bytes(args.n_bytes)}")
    print_key_value(key="Object type", value=f"{args.object_type}")
    print_key_value(key="Reuse allocation", value=f"{args.reuse_alloc}")
    print_key_value(key="Transfer API", value=f"{'AM' if args.enable_am else 'TAG'}")
    print_key_value(key="Delay progress", value=f"{delay_progress_str}")
    print_key_value(key="UCX_TLS", value=f"{ucp.get_config()['TLS']}")
    print_key_value(key="UCX_NET_DEVICES", value=f"{ucp.get_config()['NET_DEVICES']}")
    print_separator(separator="=")
    if args.object_type == "numpy":
        print_key_value(key="Device(s)", value="CPU-only")
        s_aff = (
            args.server_cpu_affinity
            if args.server_cpu_affinity >= 0
            else "affinity not set"
        )
        c_aff = (
            args.client_cpu_affinity
            if args.client_cpu_affinity >= 0
            else "affinity not set"
        )
        print_key_value(key="Server CPU", value=f"{s_aff}")
        print_key_value(key="Client CPU", value=f"{c_aff}")
    else:
        print_key_value(key="Device(s)", value=f"{args.server_dev}, {args.client_dev}")
    print_separator(separator="=")
    print_key_value("Bandwidth (average)", value=f"{bw_avg}/s")
    print_key_value("Bandwidth (median)", value=f"{bw_med}/s")
    print_key_value("Latency (average)", value=f"{lat_avg} ns")
    print_key_value("Latency (median)", value=f"{lat_med} ns")
    if not args.no_detailed_report:
        print_separator(separator="=")
        print_key_value(key="Iterations", value="Bandwidth, Latency")
        print_separator(separator="-")
        for i, t in enumerate(times):
            ts = format_bytes(2 * args.n_bytes / t)
            lat = int(t * 1e9 / 2)
            print_key_value(key=i, value=f"{ts}/s, {lat}ns")
def server(queue, args):
    """Benchmark server: accepts one client and echoes every received message
    back (`n_iter + n_warmup_iter` data messages, plus a wireup message on
    the TAG path).
    """
    if args.server_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.server_cpu_affinity])

    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.server_dev)
    else:
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.server_dev],
        )
        xp.cuda.runtime.setDevice(args.server_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(
        feature_flags=(
            ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
        )
    )
    worker = ucx_api.UCXWorker(ctx)
    register_am_allocators(args, worker)

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    ep = None

    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None
        op_completed()

    def _tag_recv_handle(request, exception, ep, msg):
        assert exception is None
        # Echo the received message back to the client.
        req = ucx_api.tag_send_nb(
            ep, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
        )
        # NOTE(review): when tag_send_nb returns None the send presumably
        # completed inline and _send_handle will not fire, so the operation
        # is accounted for here -- confirm against tag_send_nb semantics.
        if req is None:
            op_completed()

    def _am_recv_handle(recv_obj, exception, ep):
        assert exception is None
        msg = Array(recv_obj)
        # Echo the received AM message back to the client.
        ucx_api.am_send_nbx(ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,))

    def _listener_handler(conn_request, msg):
        # NOTE(review): `global ep` stores the endpoint at module level rather
        # than rebinding the enclosing function's `ep = None`; the reference
        # is still kept alive -- confirm `nonlocal` was not intended.
        global ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

        # Wireup before starting to transfer data
        if args.enable_am is True:
            ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
        else:
            wireup = Array(bytearray(len(WireupMessage)))
            op_started()
            ucx_api.tag_recv_nb(
                worker,
                wireup,
                wireup.nbytes,
                tag=0,
                cb_func=_tag_recv_handle,
                cb_args=(ep, wireup),
            )

        # Post all expected receives up front; completions are driven by the
        # progress loop below.
        for i in range(args.n_iter + args.n_warmup_iter):
            if args.enable_am is True:
                ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
            else:
                if not args.reuse_alloc:
                    msg = Array(xp.zeros(args.n_bytes, dtype="u1"))
                op_started()
                ucx_api.tag_recv_nb(
                    worker,
                    msg,
                    msg.nbytes,
                    tag=0,
                    cb_func=_tag_recv_handle,
                    cb_args=(ep, msg),
                )

    # With reuse_alloc a single receive buffer is shared across iterations
    # and passed into the listener callback.
    if not args.enable_am and args.reuse_alloc:
        msg = Array(xp.zeros(args.n_bytes, dtype="u1"))
    else:
        msg = None

    listener = ucx_api.UCXListener(
        worker=worker, port=args.port or 0, cb_func=_listener_handler, cb_args=(msg,)
    )
    queue.put(listener.port)

    # Wait until at least one operation has started (client connected).
    while outstanding[0] == 0:
        worker.progress()

    # +1 to account for wireup message
    if args.delay_progress:
        while finished[0] < args.n_iter + args.n_warmup_iter + 1 and (
            outstanding[0] >= args.max_outstanding
            or finished[0] + args.max_outstanding
            >= args.n_iter + args.n_warmup_iter + 1
        ):
            worker.progress()
    else:
        while finished[0] != args.n_iter + args.n_warmup_iter + 1:
            worker.progress()