def _server_cancel(queue, transfer_api):
    """Accept one client connection and immediately close the endpoint.

    Because nothing is ever sent, receives posted by the client remain
    unmatched and are canceled on the client side.
    """
    if transfer_api == "am":
        feature = ucx_api.Feature.AM
    else:
        feature = ucx_api.Feature.TAG
    ctx = ucx_api.UCXContext(feature_flags=(feature,))
    worker = ucx_api.UCXWorker(ctx)

    # Single-slot holder so the listener callback can hand the endpoint back
    # to this scope, where it can be used outside the callback.
    holder = [None]

    def _listener_handler(conn_request):
        holder[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker, conn_request, endpoint_error_handling=True,
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    queue.put(listener.port)

    while holder[0] is None:
        worker.progress()

    holder[0].close()
    worker.progress()
def _test_peer_communication_tag(queue, rank, msg_size):
    """Pass a random message around a ring of peers with the TAG API.

    Each process publishes ``(rank, worker address)`` on ``queue``, then reads
    its right and left neighbours from it. Rank 0 originates the message and
    checks it returns unchanged; every other rank forwards what it receives.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)
    queue.put((rank, worker.get_address()))
    right_rank, right_address = queue.get()
    left_rank, left_address = queue.get()

    def _connect(address):
        return ucx_api.UCXEndpoint.create_from_worker_address(
            worker, address, endpoint_error_handling=True,
        )

    right_ep = _connect(right_address)
    left_ep = _connect(left_address)

    recv_msg = bytearray(msg_size)
    if rank != 0:
        # Relay: receive from the left neighbour, pass on to the right.
        blocking_recv(worker, left_ep, recv_msg, rank)
        blocking_send(worker, right_ep, recv_msg, right_rank)
        return

    send_msg = bytes(os.urandom(msg_size))
    blocking_send(worker, right_ep, send_msg, right_rank)
    blocking_recv(worker, left_ep, recv_msg, rank)
    assert send_msg == recv_msg
def test_force_requests():
    """Post ``put_nb`` repeatedly until UCX hands back a pending request.

    Some transport combinations never exhaust their resources, so the test
    is marked xfail once the retry budget is spent without getting a request.
    """
    msg_size = 1024
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)
    self_mem = ucx_api.RemoteMemory(rkey, mem.address, msg_size)

    attempts = 0
    payload = bytes(os.urandom(msg_size))
    request = self_mem.put_nb(payload, _)
    while request is None:
        attempts += 1
        request = self_mem.put_nb(payload, _)
        # Escape hatch for machines (e.g. desktop PCs) whose transports are
        # never exhausted, so the test still terminates.
        if attempts > 10000:
            pytest.xfail("Could not generate a request")

    blocking_flush(worker)
    while worker.progress():
        pass

    # Keep posting until a put returns no request object.
    while self_mem.put_nb(payload, _):
        pass
    blocking_flush(worker)
def test_feature_flags_mismatch(feature_flag):
    """Each transfer API raises ValueError when its feature flag is missing.

    The context enables only ``feature_flag``, so every API family other than
    that one must refuse to run.
    """
    ctx = ucx_api.UCXContext(feature_flags=(feature_flag,))
    worker = ucx_api.UCXWorker(ctx)
    addr = worker.get_address()
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, addr, endpoint_error_handling=False
    )
    msg = Array(bytearray(10))

    # (required feature, expected error pattern, calls that need the feature)
    expectations = [
        (
            ucx_api.Feature.TAG,
            "UCXContext must be created with `Feature.TAG`",
            [
                lambda: ucx_api.tag_send_nb(ep, msg, msg.nbytes, 0, None),
                lambda: ucx_api.tag_recv_nb(worker, msg, msg.nbytes, 0, None),
            ],
        ),
        (
            ucx_api.Feature.STREAM,
            "UCXContext must be created with `Feature.STREAM`",
            [
                lambda: ucx_api.stream_send_nb(ep, msg, msg.nbytes, None),
                lambda: ucx_api.stream_recv_nb(ep, msg, msg.nbytes, None),
            ],
        ),
        (
            ucx_api.Feature.AM,
            "UCXContext must be created with `Feature.AM`",
            [
                lambda: ucx_api.am_send_nbx(ep, msg, msg.nbytes, None),
                lambda: ucx_api.am_recv_nb(ep, None),
            ],
        ),
    ]
    for required, pattern, calls in expectations:
        if feature_flag == required:
            continue
        for call in calls:
            with pytest.raises(ValueError, match=pattern):
                call()
def test_ucxio_seek_bad(seek_loc, seek_flag):
    """Read back after seeking to ``seek_loc`` with whence ``seek_flag``.

    A positive ``seek_loc`` is expected to yield an empty read; any other
    value is expected to read the whole written message back.
    """
    msg_size = 1024
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)
    uio = ucx_api.UCXIO(mem.address, msg_size, rkey)

    payload = bytes(os.urandom(msg_size))
    uio.write(payload)
    uio.seek(seek_loc, seek_flag)
    data_read = uio.read(len(payload))

    expected = b"" if seek_loc > 0 else payload
    assert data_read == expected
def _server_probe(queue, transfer_api):
    """Server that probes for and receives a message after the client left.

    Calling ``progress()`` inside callback functions is illegal. The listener
    callback therefore only stores the endpoint. All progressing happens in
    this function, so the worker can still be progressed after Python
    blocking calls.
    """
    feature_flags = (
        ucx_api.Feature.AM if transfer_api == "am" else ucx_api.Feature.TAG,
    )
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)

    # Holder written by the listener callback. It is kept separate from `ep`
    # below so that rebinding `ep` does not alter what the closure writes to.
    holder = [None]

    def _listener_handler(conn_request):
        holder[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker, conn_request, endpoint_error_handling=True,
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    queue.put(listener.port)

    while holder[0] is None:
        worker.progress()
    ep = holder[0]

    # Ensure wireup and inform client before it can disconnect
    if transfer_api == "am":
        wireup = blocking_am_recv(worker, ep)
    else:
        wireup = bytearray(len(WireupMessage))
        blocking_recv(worker, ep, wireup)
    queue.put("wireup completed")

    # Ensure client has disconnected -- endpoint is not alive anymore
    while ep.is_alive() is True:
        worker.progress()

    # Probe/receive message even after the remote endpoint has disconnected
    if transfer_api == "am":
        while ep.am_probe() is False:
            worker.progress()
        received = blocking_am_recv(worker, ep)
    else:
        while worker.tag_probe(0) is False:
            worker.progress()
        received = bytearray(len(DataMessage))
        blocking_recv(worker, ep, received)
    assert wireup == WireupMessage
    assert received == DataMessage
def test_pickle_ucx_address():
    """A pickled worker address round-trips with identical hash and bytes."""
    ctx = ucx_api.UCXContext()
    worker = ucx_api.UCXWorker(ctx)
    address = worker.get_address()

    pickled = pickle.dumps(address)
    expected_hash = hash(address)
    expected_bytes = bytes(address)

    restored = pickle.loads(pickled)
    assert expected_hash == hash(restored)
    assert expected_bytes == bytes(restored)
def _echo_server(get_queue, put_queue, msg_size, datatype):
    """AM echo server: send every received message back to the client.

    Calling ``progress()`` from callbacks is illegal, so the work is done as
    a chain of callbacks (listener -> receive -> send). This function
    progresses the worker until something arrives on ``get_queue``.
    """
    data = get_data()[datatype]
    ctx = ucx_api.UCXContext(
        config_dict={"RNDV_THRESH": str(RNDV_THRESH)},
        feature_flags=(ucx_api.Feature.AM,),
    )
    worker = ucx_api.UCXWorker(ctx)
    worker.register_am_allocator(data["allocator"], data["memory_type"])

    # A reference to the listener's endpoint is stored here, through
    # `nonlocal` in the handler, so it does not go out of scope too early.
    ep = None

    def _send_handle(request, exception, msg):
        # `msg` is passed to the handler only to keep the buffer alive until
        # the send completes.
        assert exception is None

    def _recv_handle(recv_obj, exception, ep):
        assert exception is None
        msg = Array(recv_obj)
        ucx_api.am_send_nbx(ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,))

    def _listener_handler(conn_request):
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker, conn_request, endpoint_error_handling=True,
        )
        # Wireup
        ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))
        # Data
        ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    while True:
        worker.progress()
        try:
            # Non-blocking poll; a timeout would be ignored with block=False.
            get_queue.get(block=False)
        except QueueIsEmpty:
            continue
        break
def _echo_client(msg_size, port, endpoint_error_handling):
    """Send a random TAG message to the echo server and check the echo."""
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create(
        worker, "localhost", port, endpoint_error_handling=endpoint_error_handling,
    )

    outgoing = bytes(os.urandom(msg_size))
    incoming = bytearray(msg_size)
    blocking_send(worker, ep, outgoing)
    blocking_recv(worker, ep, incoming)
    assert outgoing == incoming
def test_rkey_unpack():
    """A packed rkey can be unpacked on a self-connected endpoint."""
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, 1024)
    packed = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    error_handling = get_endpoint_error_handling_default()
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=error_handling,
    )
    assert ep.unpack_rkey(packed) is not None
def test_listener_ip_port():
    """A listener on port 0 reports a non-empty IP and a valid port number."""
    ctx = ucx_api.UCXContext()
    worker = ucx_api.UCXWorker(ctx)

    def _listener_handler(conn_request):
        pass

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)

    assert isinstance(listener.ip, str) and listener.ip
    assert isinstance(listener.port, int) and 0 <= listener.port <= 65535
def test_flush():
    """Flush a self-connected endpoint and wait for any request to finish.

    ``ep.flush`` returns either ``None`` or a request. When a request is
    returned, its ``info`` status is polled while progressing the worker
    until it leaves "pending", and it must then be "finished".
    """
    ctx = ucx_api.UCXContext({})
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=True,
    )
    req = ep.flush(_)
    # Only a returned request has a status to wait on; dereferencing
    # `req.info` when `req` is None would raise AttributeError.
    if req is not None:
        info = req.info
        while info["status"] == "pending":
            worker.progress()
        assert info["status"] == "finished"
def test_get_config():
    """With UCX_TLS unset, ``get_config`` reports ``TLS == "all"``."""
    # Temporarily remove any user-defined UCX_TLS so the default is observed.
    tls = os.environ.pop("UCX_TLS", None)
    try:
        ctx = ucx_api.UCXContext()
        config = ctx.get_config()
        assert isinstance(config, dict)
        assert config["TLS"] == "all"
    finally:
        # Restore the user-defined UCX_TLS even when an assertion fails, so
        # later tests in the same process see the original environment.
        if tls is not None:
            os.environ["UCX_TLS"] = tls
def _echo_server(get_queue, put_queue, msg_size):
    """TAG echo server: send the received message back to the client.

    Calling ``progress()`` from callbacks is illegal, so the work is done as
    a chain of callbacks (listener -> receive -> send). This function
    progresses the worker until something arrives on ``get_queue``.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)

    # A reference to the listener's endpoint is stored here, through
    # `nonlocal` in the handler, so it does not go out of scope too early.
    ep = None

    def _send_handle(request, exception, msg):
        # `msg` is passed to the handler only to keep the buffer alive until
        # the send completes.
        assert exception is None

    def _recv_handle(request, exception, ep, msg):
        assert exception is None
        ucx_api.tag_send_nb(
            ep, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
        )

    def _listener_handler(conn_request):
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker, conn_request, endpoint_error_handling=True,
        )
        msg = Array(bytearray(msg_size))
        ucx_api.tag_recv_nb(
            worker, msg, msg.nbytes, tag=0, cb_func=_recv_handle, cb_args=(ep, msg)
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    while True:
        worker.progress()
        try:
            # Non-blocking poll; a timeout would be ignored with block=False.
            get_queue.get(block=False)
        except QueueIsEmpty:
            continue
        break
def _client(port, server_close_callback):
    """Connect to the server and either close the endpoint or wait for close.

    If ``server_close_callback`` is True, the client closes its endpoint
    itself. Otherwise it installs a close callback and progresses the worker
    until ``closed[0]`` is no longer False (presumably set by
    ``_close_callback``; verify against its definition).
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create(
        worker, ucx_api.get_address(), port, endpoint_error_handling=True,
    )

    if server_close_callback is not True:
        closed = [False]
        ep.set_close_callback(functools.partial(_close_callback, closed))
        while closed[0] is False:
            worker.progress()
        return

    ep.close()
    worker.progress()
def _client_cancel(queue, transfer_api):
    """Post a receive that the server never satisfies and check cancelation.

    The server closes without sending anything. After the endpoint dies, the
    in-flight receive is canceled and its callback gets a ``UCXCanceled``.
    """
    use_am = transfer_api == "am"
    feature = ucx_api.Feature.AM if use_am else ucx_api.Feature.TAG
    ctx = ucx_api.UCXContext(feature_flags=(feature,))
    worker = ucx_api.UCXWorker(ctx)
    port = queue.get()
    ep = ucx_api.UCXEndpoint.create(
        worker, get_address(), port, endpoint_error_handling=True,
    )

    result = [None]
    if use_am:
        ucx_api.am_recv_nb(ep, cb_func=_handler, cb_args=(result,))
        pattern = ".*am_recv.*"
    else:
        # `msg` stays bound for the rest of the function so the receive
        # buffer outlives the pending request.
        msg = Array(bytearray(1))
        ucx_api.tag_recv_nb(
            worker, msg, msg.nbytes, tag=0,
            cb_func=_handler, cb_args=(result,), ep=ep,
        )
        pattern = ".*tag_recv_nb.*"

    while ep.is_alive():
        worker.progress()
    canceled = worker.cancel_inflight_messages()
    while result[0] is None:
        worker.progress()

    assert canceled == 1
    assert isinstance(result[0], UCXCanceled)
    assert re.match(pattern, result[0].args[0])
def test_ucxio(msg_size):
    """Bytes written through UCXIO read back identically after rewinding."""
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)
    uio = ucx_api.UCXIO(mem.address, msg_size, rkey)

    payload = bytes(os.urandom(msg_size))
    uio.write(payload)
    uio.seek(0)
    assert payload == uio.read(msg_size)
def _test_peer_communication_rma(queue, rank, msg_size):
    """Ring test for RMA: each rank reads both neighbours' memory.

    Every rank fills its own buffer with its rank value and publishes
    ``(rank, address, packed rkey, base)`` on ``queue``. It then ``get_nbi``s
    the left and right neighbours' buffers and verifies them. A ring of
    blocking tag send/recv afterwards serves as a barrier, presumably so no
    rank exits while a neighbour may still be reading its memory.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.RMA, ucx_api.Feature.TAG))
    worker = ucx_api.UCXWorker(ctx)
    self_address = worker.get_address()
    mem_handle = ctx.alloc(msg_size)
    self_base = mem_handle.address
    self_prkey = mem_handle.pack_rkey()

    self_ep, self_mem = _rma_setup(worker, self_address, self_prkey, self_base, msg_size)
    own_fill = bytes(repeat(rank, msg_size))
    if not self_mem.put_nbi(own_fill):
        blocking_flush(self_ep)

    queue.put((rank, self_address, self_prkey, self_base))
    right_rank, right_address, right_prkey, right_base = queue.get()
    left_rank, left_address, left_prkey, left_base = queue.get()

    right_ep, right_mem = _rma_setup(worker, right_address, right_prkey, right_base, msg_size)
    right_buf = bytearray(msg_size)
    right_mem.get_nbi(right_buf)

    left_ep, left_mem = _rma_setup(worker, left_address, left_prkey, left_base, msg_size)
    left_buf = bytearray(msg_size)
    left_mem.get_nbi(left_buf)

    blocking_flush(worker)
    assert left_buf == bytes(repeat(left_rank, msg_size))
    assert right_buf == bytes(repeat(right_rank, msg_size))

    # Barrier: pass an 8-byte token around the ring with blocking tag ops.
    token = bytearray(8)
    if rank == 0:
        seed = bytes(os.urandom(8))
        blocking_send(worker, right_ep, seed, right_rank)
        blocking_recv(worker, left_ep, token, rank)
    else:
        blocking_recv(worker, left_ep, token, rank)
        blocking_send(worker, right_ep, token, right_rank)
def _client_probe(queue, transfer_api):
    """Send a wireup and a data message, then wait for the server's ack.

    The client must not disconnect before the server has received the
    wireup message, so it blocks on ``queue`` until "wireup completed".
    """
    use_am = transfer_api == "am"
    feature_flags = (ucx_api.Feature.AM if use_am else ucx_api.Feature.TAG,)
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)
    port = queue.get()
    ep = ucx_api.UCXEndpoint.create(
        worker, get_address(), port, endpoint_error_handling=True,
    )

    send = blocking_am_send if use_am else blocking_send
    for message in (WireupMessage, DataMessage):
        send(worker, ep, message)

    # Wait for wireup before disconnecting
    assert queue.get() == "wireup completed"
def test_implicit(msg_size):
    """Round-trip a message through self-RMA using implicit put/get."""
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)
    self_mem = ucx_api.RemoteMemory(rkey, mem.address, msg_size)

    outgoing = bytes(os.urandom(msg_size))
    # A falsy return means the operation did not complete immediately.
    if not self_mem.put_nbi(outgoing):
        blocking_flush(ep)

    incoming = bytearray(len(outgoing))
    if not self_mem.get_nbi(incoming):
        blocking_flush(ep)

    assert outgoing == incoming
def test_ucxio_seek_good(seek_data):
    """After a valid seek, the next 4 bytes match the written data there."""
    seek_flag, seek_dest = seek_data
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, 1024)
    msg_size = mem.length
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)
    uio = ucx_api.UCXIO(mem.address, msg_size, rkey)

    payload = bytes(os.urandom(msg_size))
    uio.write(payload)
    uio.seek(seek_dest, seek_flag)

    window = 4
    chunk = uio.read(window)
    assert chunk == payload[seek_dest:seek_dest + window]
def _echo_client(msg_size, datatype, port, endpoint_error_handling):
    """AM echo client: send wireup and data, then validate both echoes."""
    data = get_data()[datatype]
    ctx = ucx_api.UCXContext(
        config_dict={"RNDV_THRESH": str(RNDV_THRESH)},
        feature_flags=(ucx_api.Feature.AM,),
    )
    worker = ucx_api.UCXWorker(ctx)
    worker.register_am_allocator(data["allocator"], data["memory_type"])
    ep = ucx_api.UCXEndpoint.create(
        worker, "localhost", port, endpoint_error_handling=endpoint_error_handling,
    )

    # The wireup message ensures the endpoints are connected; otherwise UCX
    # may not perform any rendezvous transfers.
    wireup_out = bytearray(b"wireup")
    data_out = data["generator"](msg_size)
    for payload in (wireup_out, data_out):
        blocking_am_send(worker, ep, payload)
    wireup_in = blocking_am_recv(worker, ep)
    data_in = blocking_am_recv(worker, ep)

    # With NumPy as host allocator the wireup arrives as an array; cast it so
    # the comparison below is correct.
    if datatype == "numpy":
        wireup_in = bytearray(wireup_in)
    assert bytearray(wireup_in) == wireup_out

    eager_cuda = data["memory_type"] == "cuda" and data_out.nbytes < RNDV_THRESH
    if eager_cuda:
        # Eager messages are always received on the host; without a host
        # allocator registered UCX-Py defaults to `bytearray`.
        assert data_in == bytearray(data_out.get())

    data["validator"](data_in, data_out)
def _server(queue, server_close_callback):
    """Accept one connection and, optionally, wait for its close callback.

    If ``server_close_callback`` is True, a close callback is installed on the
    accepted endpoint and the worker is progressed until ``closed[0]`` flips.
    Otherwise the server only progresses until the listener callback has run.
    Calling ``progress()`` from callbacks is illegal, so the listener callback
    only creates the endpoint and sets flags.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)
    listener_finished = [False]
    closed = [False]

    # A reference to the listener's endpoint is stored here, through
    # `nonlocal` in the handler, so it does not go out of scope too early.
    ep = None

    def _listener_handler(conn_request):
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker, conn_request, endpoint_error_handling=True,
        )
        if server_close_callback is True:
            ep.set_close_callback(functools.partial(_close_callback, closed))
        listener_finished[0] = True

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    queue.put(listener.port)

    if server_close_callback is True:
        while closed[0] is False:
            worker.progress()
        assert closed[0] is True
    else:
        while listener_finished[0] is False:
            worker.progress()
def test_ctx_map(buffer):
    """``UCXContext.map`` yields a handle whose rkey can be packed."""
    ctx = ucx_api.UCXContext({})
    handle = ctx.map(buffer)
    assert handle.pack_rkey() is not None
def test_init_unknown_option():
    """An unrecognised configuration key is rejected with UCXConfigError."""
    with pytest.raises(UCXConfigError):
        ucx_api.UCXContext({"UNKNOWN_OPTION": "3M"})
def test_map(buffer):
    """``UCXMemoryHandle.map`` yields a handle whose rkey can be packed."""
    ctx = ucx_api.UCXContext({})
    handle = ucx_api.UCXMemoryHandle.map(ctx, buffer)
    assert handle.pack_rkey() is not None
def test_alloc():
    """``UCXMemoryHandle.alloc`` yields a handle whose rkey can be packed."""
    ctx = ucx_api.UCXContext({})
    handle = ucx_api.UCXMemoryHandle.alloc(ctx, 1024)
    assert handle.pack_rkey() is not None
def test_init_options():
    """Options passed to UCXContext override the matching UCX_* env var."""
    previous = os.environ.get("UCX_SEG_SIZE")
    os.environ["UCX_SEG_SIZE"] = "2M"  # Should be ignored
    try:
        options = {"SEG_SIZE": "3M"}
        ctx = ucx_api.UCXContext(options)
        config = ctx.get_config()
        assert config["SEG_SIZE"] == options["SEG_SIZE"]
    finally:
        # Undo the override so it does not leak into later tests that run in
        # the same process.
        if previous is None:
            os.environ.pop("UCX_SEG_SIZE", None)
        else:
            os.environ["UCX_SEG_SIZE"] = previous
def test_ctx_alloc():
    """``UCXContext.alloc`` yields a handle whose rkey can be packed."""
    ctx = ucx_api.UCXContext({})
    handle = ctx.alloc(1024)
    assert handle.pack_rkey() is not None
def test_init_invalid_option():
    """A malformed value for a known option is rejected with UCXConfigError."""
    with pytest.raises(UCXConfigError):
        ucx_api.UCXContext({"SEG_SIZE": "invalid-size"})