def test_close_callback(server_close_callback):
    """Run ``_server`` and ``_client`` in child processes; both must exit 0.

    ``server_close_callback`` and the endpoint-error-handling flag (enabled
    for UCX >= 1.10.0) are forwarded to both processes. The server reports
    its listening port through ``queue``.

    Any child still alive when the test finishes (join timeout or failed
    assertion) is terminated. Otherwise the non-daemonic process would
    outlive the test, and multiprocessing would block on it at interpreter
    exit.
    """
    endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0)
    queue = mp.Queue()
    server = mp.Process(
        target=_server,
        args=(queue, endpoint_error_handling, server_close_callback),
    )
    server.start()
    client = None
    try:
        port = queue.get()
        client = mp.Process(
            target=_client,
            args=(port, endpoint_error_handling, server_close_callback),
        )
        client.start()
        client.join(timeout=10)
        server.join(timeout=10)
        assert client.exitcode == 0
        assert server.exitcode == 0
    finally:
        for proc in (client, server):
            if proc is not None and proc.is_alive():
                proc.terminate()
                proc.join()
def test_server_client(msg_size):
    """Echo ``msg_size`` messages between child processes; both must exit 0.

    The server reports its port on ``get_queue``. Once the client has
    finished, the server is told to stop by putting ``"Finished"`` on
    ``put_queue``.

    ``Process.exitcode`` is ``None`` while a process is still running, so
    the exit codes are compared with ``== 0``. The previous ``not exitcode``
    check silently passed when a join timed out. Children still alive at
    the end, for example a server never sent ``"Finished"`` because the
    client assertion failed, are terminated.
    """
    endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0)
    put_queue, get_queue = mp.Queue(), mp.Queue()
    server = mp.Process(
        target=_echo_server,
        args=(put_queue, get_queue, msg_size, endpoint_error_handling),
    )
    server.start()
    client = None
    try:
        port = get_queue.get()
        client = mp.Process(
            target=_echo_client, args=(msg_size, port, endpoint_error_handling)
        )
        client.start()
        client.join(timeout=10)
        assert client.exitcode == 0
        put_queue.put("Finished")
        server.join(timeout=10)
        assert server.exitcode == 0
    finally:
        for proc in (client, server):
            if proc is not None and proc.is_alive():
                proc.terminate()
                proc.join()
def _listener_handler(conn_request, msg):
    """Accept the client connection and pre-post all benchmark receives.

    Builds the endpoint from ``conn_request`` and stores it in the global
    ``ep``. It then posts one receive for the wireup message and one per
    iteration (``args.n_iter``).

    In TAG mode each posted receive is counted with ``op_started()``. The
    buffer ``msg`` is reused for every iteration when ``args.reuse_alloc``
    is set; otherwise a fresh zeroed ``args.n_bytes`` buffer is allocated
    per iteration. In AM mode ``msg`` is not used.

    Relies on ``worker``, ``args``, ``xp``, ``op_started``,
    ``_am_recv_handle`` and ``_tag_recv_handle`` from an enclosing scope.
    """
    global ep
    ep = ucx_api.UCXEndpoint.create_from_conn_request(
        worker,
        conn_request,
        endpoint_error_handling=ucx_api.get_ucx_version() >= (1, 10, 0),
    )
    use_am = args.enable_am is True

    def post_tag_recv(buffer):
        # Count the operation before posting, then hand the buffer to UCX.
        op_started()
        ucx_api.tag_recv_nb(
            worker,
            buffer,
            buffer.nbytes,
            tag=0,
            cb_func=_tag_recv_handle,
            cb_args=(ep, buffer),
        )

    # Wireup before starting to transfer data
    if use_am:
        ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
    else:
        post_tag_recv(Array(bytearray(len(WireupMessage))))

    for _ in range(args.n_iter):
        if use_am:
            ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
            continue
        if not args.reuse_alloc:
            msg = Array(xp.zeros(args.n_bytes, dtype="u1"))
        post_tag_recv(msg)
# NOTE(review): the four statements below appear to be the tail of a test
# whose ``def`` line is outside this view; they are reproduced unchanged.
os.environ["UCX_SEG_SIZE"] = "2M"
ctx = ucx_api.UCXContext()
config = ctx.get_config()
assert config["SEG_SIZE"] == os.environ["UCX_SEG_SIZE"]


def test_init_options():
    """Options passed to ``UCXContext`` take precedence over ``UCX_*`` env vars.

    The previous value of ``UCX_SEG_SIZE`` (or its absence) is restored
    afterwards, so the value set here does not leak into later tests
    running in the same process.
    """
    previous = os.environ.get("UCX_SEG_SIZE")
    os.environ["UCX_SEG_SIZE"] = "2M"  # Should be ignored
    try:
        options = {"SEG_SIZE": "3M"}
        ctx = ucx_api.UCXContext(options)
        config = ctx.get_config()
        assert config["SEG_SIZE"] == options["SEG_SIZE"]
    finally:
        if previous is None:
            os.environ.pop("UCX_SEG_SIZE", None)
        else:
            os.environ["UCX_SEG_SIZE"] = previous


@pytest.mark.skipif(
    ucx_api.get_ucx_version() >= (1, 12, 0),
    reason="Beginning with UCX >= 1.12, it's only possible to validate "
    "UCP options but not options from other modules such as UCT. "
    "See https://github.com/openucx/ucx/issues/7519.",
)
def test_init_unknown_option():
    """An unrecognised option name makes ``UCXContext`` raise ``UCXConfigError``."""
    options = {"UNKNOWN_OPTION": "3M"}
    with pytest.raises(UCXConfigError):
        ucx_api.UCXContext(options)


def test_init_invalid_option():
    """A malformed option value makes ``UCXContext`` raise ``UCXConfigError``."""
    options = {"SEG_SIZE": "invalid-size"}
    with pytest.raises(UCXConfigError):
        ucx_api.UCXContext(options)
def get_endpoint_error_handling_default():
    """Return the default for endpoint error handling.

    Enabled (``True``) only when ``ucx_api.get_ucx_version()`` reports
    version 1.10.0 or newer.
    """
    minimum_version = (1, 10, 0)
    return ucx_api.get_ucx_version() >= minimum_version
def _client_array_module(args):
    """Import and configure the array module used for the client's buffers.

    Args:
        args: parsed benchmark options; reads ``object_type``,
            ``client_dev`` and, for the RMM path, ``rmm_init_pool_size``.

    Returns:
        ``numpy`` for ``object_type == "numpy"``. Otherwise ``cupy``, with
        ``args.client_dev`` selected as the current CUDA device. For any
        other object type, CuPy allocations are also routed through a
        freshly initialised RMM pool on that device.
    """
    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.client_dev)
    else:
        import cupy as xp
        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.client_dev],
        )
        xp.cuda.runtime.setDevice(args.client_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)
    return xp


def _client_print_report(args, times, config):
    """Print the roundtrip benchmark summary.

    Args:
        args: parsed benchmark options.
        times: per-iteration durations (differences of ``clock()``).
            Bandwidth is printed per second, so these are presumably
            seconds.
        config: UCX configuration mapping; its ``"TLS"`` and
            ``"NET_DEVICES"`` entries are printed.
    """
    import numpy as np

    delay_progress_str = (f"True ({args.max_outstanding})"
                          if args.delay_progress is True else "False")
    print("Roundtrip benchmark")
    print("--------------------------")
    print(f"n_iter | {args.n_iter}")
    print(f"n_bytes | {format_bytes(args.n_bytes)}")
    print(f"object | {args.object_type}")
    print(f"reuse alloc | {args.reuse_alloc}")
    print(f"transfer API | {'AM' if args.enable_am else 'TAG'}")
    print(f"delay progress | {delay_progress_str}")
    print(f"UCX_TLS | {config['TLS']}")
    print(f"UCX_NET_DEVICES | {config['NET_DEVICES']}")
    print("==========================")
    if args.object_type == "numpy":
        print("Device(s) | CPU-only")
        s_aff = (args.server_cpu_affinity
                 if args.server_cpu_affinity >= 0 else "affinity not set")
        c_aff = (args.client_cpu_affinity
                 if args.client_cpu_affinity >= 0 else "affinity not set")
        print(f"Server CPU | {s_aff}")
        print(f"Client CPU | {c_aff}")
    else:
        print(f"Device(s) | {args.server_dev}, {args.client_dev}")
    # Each iteration is one send plus one receive of n_bytes, hence the 2x.
    avg = format_bytes(2 * args.n_iter * args.n_bytes / sum(times))
    med = format_bytes(2 * args.n_bytes / np.median(times))
    print(f"Average | {avg}/s")
    print(f"Median | {med}/s")
    if not args.no_detailed_report:
        print("--------------------------")
        print("Iterations")
        print("--------------------------")
        for i, t in enumerate(times):
            ts = format_bytes(2 * args.n_bytes / t).rjust(9)
            print(f"{i:03d} |{ts}/s")


def client(queue, port, server_address, args):
    """Run the client side of the UCX core-API roundtrip benchmark.

    Pins the client CPU if requested, then sets up the array module. It
    creates a UCX context, worker and endpoint to ``server_address:port``,
    performs a wireup exchange, and times ``args.n_iter`` send+receive
    roundtrips before printing a report.

    Args:
        queue: not used by this function.
        port: server listening port.
        server_address: server host address.
        args: parsed benchmark options.
    """
    if args.client_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.client_cpu_affinity])

    xp = _client_array_module(args)

    feature = (ucx_api.Feature.AM if args.enable_am is True
               else ucx_api.Feature.TAG)
    ctx = ucx_api.UCXContext(feature_flags=(feature,))
    worker = ucx_api.UCXWorker(ctx)
    register_am_allocators(args, worker)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        server_address,
        port,
        endpoint_error_handling=ucx_api.get_ucx_version() >= (1, 10, 0),
    )

    send_msg = xp.arange(args.n_bytes, dtype="u1")
    if args.reuse_alloc:
        recv_msg = xp.zeros(args.n_bytes, dtype="u1")

    # Wireup: one untimed exchange before the measured loop.
    if args.enable_am:
        blocking_am_send(worker, ep, send_msg)
        blocking_am_recv(worker, ep)
    else:
        wireup_recv = bytearray(len(WireupMessage))
        blocking_send(worker, ep, WireupMessage)
        blocking_recv(worker, ep, wireup_recv)

    # In-flight / completed counters for the non-blocking (delay_progress)
    # path, updated under op_lock.
    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def maybe_progress():
        # Progress the worker until fewer than max_outstanding ops remain.
        while outstanding[0] >= args.max_outstanding:
            worker.progress()

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    if args.cuda_profile:
        xp.cuda.profiler.start()

    times = []
    for _ in range(args.n_iter):
        start = clock()
        if args.enable_am:
            blocking_am_send(worker, ep, send_msg)
            blocking_am_recv(worker, ep)
        else:
            if not args.reuse_alloc:
                recv_msg = xp.zeros(args.n_bytes, dtype="u1")
            if args.delay_progress:
                maybe_progress()
                non_blocking_send(worker, ep, send_msg, op_started, op_completed)
                maybe_progress()
                non_blocking_recv(worker, ep, recv_msg, op_started, op_completed)
            else:
                blocking_send(worker, ep, send_msg)
                blocking_recv(worker, ep, recv_msg)
        stop = clock()
        times.append(stop - start)

    # Drain outstanding non-blocking ops. Only the TAG path issues them; in
    # AM mode `finished` never advances, so this loop would spin forever.
    if args.delay_progress and not args.enable_am:
        while finished[0] != 2 * args.n_iter:
            worker.progress()

    if args.cuda_profile:
        xp.cuda.profiler.stop()

    assert len(times) == args.n_iter
    # Report the configuration of the context actually used for the run.
    _client_print_report(args, times, ctx.get_config())