def test_feature_flags_mismatch(feature_flag):
    """Check that transfer calls reject a context lacking their feature.

    A context is created with only ``feature_flag`` enabled, and a worker
    plus an endpoint (built from the worker's own address) are created on
    top of it. For each of TAG, STREAM and AM that differs from
    ``feature_flag``, both the send and the receive call of that API must
    raise ``ValueError`` with the matching message.
    """
    tag_error = "UCXContext must be created with `Feature.TAG`"
    stream_error = "UCXContext must be created with `Feature.STREAM`"
    am_error = "UCXContext must be created with `Feature.AM`"

    def rejects(message):
        # Each call gets its own context manager so every call is checked.
        return pytest.raises(ValueError, match=message)

    context = ucx_api.UCXContext(feature_flags=(feature_flag,))
    worker = ucx_api.UCXWorker(context)
    address = worker.get_address()
    endpoint = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, address, endpoint_error_handling=False
    )
    payload = Array(bytearray(10))

    if feature_flag != ucx_api.Feature.TAG:
        with rejects(tag_error):
            ucx_api.tag_send_nb(endpoint, payload, payload.nbytes, 0, None)
        with rejects(tag_error):
            ucx_api.tag_recv_nb(worker, payload, payload.nbytes, 0, None)

    if feature_flag != ucx_api.Feature.STREAM:
        with rejects(stream_error):
            ucx_api.stream_send_nb(endpoint, payload, payload.nbytes, None)
        with rejects(stream_error):
            ucx_api.stream_recv_nb(endpoint, payload, payload.nbytes, None)

    if feature_flag != ucx_api.Feature.AM:
        with rejects(am_error):
            ucx_api.am_send_nbx(endpoint, payload, payload.nbytes, None)
        with rejects(am_error):
            ucx_api.am_recv_nb(endpoint, None)
def blocking_am_recv(worker, ep):
    """Post an active-message receive on ``ep`` and wait for its result.

    A one-element list is handed to ``blocking_am_recv_handler`` through
    ``cb_args``; ``worker.progress()`` is called repeatedly until the first
    element of that list is no longer ``None``, and that element is
    returned.

    NOTE(review): ``blocking_am_recv_handler`` is defined elsewhere; it is
    presumably what fills the slot -- confirm against its definition.
    """
    slot = [None]
    ucx_api.am_recv_nb(
        ep,
        cb_func=blocking_am_recv_handler,
        cb_args=(slot,),
    )
    while True:
        outcome = slot[0]
        if outcome is not None:
            return outcome
        worker.progress()
def _listener_handler(conn_request):
    """Accept ``conn_request`` and post two active-message receives.

    The endpoint is created from the connection request and stored in the
    global ``ep``. Two identical receives are then posted on it, the first
    for the wireup message and the second for the data message.

    ``worker``, ``endpoint_error_handling`` and ``_recv_handle`` are read
    from a scope that is not visible in this block.
    """
    global ep
    ep = ucx_api.UCXEndpoint.create_from_conn_request(
        worker,
        conn_request,
        endpoint_error_handling=endpoint_error_handling,
    )

    # One receive per expected message, posted in this order.
    for _stage in ("wireup", "data"):
        ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))
def _client_cancel(queue, transfer_api):
    """Client whose single pending receive is expected to be canceled.

    The client connects to the port read from ``queue`` and posts one
    receive (active message when ``transfer_api == "am"``, tag otherwise).
    It progresses the worker until the endpoint is no longer alive, then
    cancels inflight messages and keeps progressing until the callback has
    stored a result. Exactly one message must have been canceled, and the
    stored result must be a ``UCXCanceled`` whose first argument matches
    the name of the receive call used.
    """
    use_am = transfer_api == "am"
    if use_am:
        feature = ucx_api.Feature.AM
    else:
        feature = ucx_api.Feature.TAG

    context = ucx_api.UCXContext(feature_flags=(feature,))
    worker = ucx_api.UCXWorker(context)
    server_port = queue.get()
    ep = ucx_api.UCXEndpoint.create(
        worker,
        get_address(),
        server_port,
        endpoint_error_handling=True,
    )

    # Slot handed to the callback through cb_args; polled below.
    outcome = [None]

    if use_am:
        ucx_api.am_recv_nb(ep, cb_func=_handler, cb_args=(outcome,))
        expected_pattern = ".*am_recv.*"
    else:
        buffer = Array(bytearray(1))
        ucx_api.tag_recv_nb(
            worker,
            buffer,
            buffer.nbytes,
            tag=0,
            cb_func=_handler,
            cb_args=(outcome,),
            ep=ep,
        )
        expected_pattern = ".*tag_recv_nb.*"

    # Keep progressing until the endpoint reports it is no longer alive.
    while ep.is_alive():
        worker.progress()

    n_canceled = worker.cancel_inflight_messages()

    # Progress until the callback has stored something in the slot.
    while outcome[0] is None:
        worker.progress()

    assert n_canceled == 1
    assert isinstance(outcome[0], UCXCanceled)
    assert re.match(expected_pattern, outcome[0].args[0])
def _listener_handler(conn_request, msg):
    """Accept ``conn_request`` and post the wireup and data receives.

    The endpoint is created from the connection request (with endpoint
    error handling enabled) and stored in the global ``ep``. One receive is
    posted for the wireup message, followed by
    ``args.n_iter + args.n_warmup_iter`` receives for data.

    When ``args.enable_am is True`` the receives are active-message
    receives. Otherwise they are tag receives with ``tag=0``, each preceded
    by a call to ``op_started()``; the data buffer is ``msg`` unless
    ``args.reuse_alloc`` is falsy, in which case a new zero-filled buffer of
    ``args.n_bytes`` bytes is allocated for every iteration.

    ``worker``, ``args``, ``xp``, ``WireupMessage``, ``op_started`` and the
    two receive callbacks are read from a scope not visible in this block.
    """
    global ep
    ep = ucx_api.UCXEndpoint.create_from_conn_request(
        worker,
        conn_request,
        endpoint_error_handling=True,
    )

    def _post_am_recv():
        ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))

    def _post_tag_recv(buffer):
        # op_started() is called before every tag receive is posted.
        op_started()
        ucx_api.tag_recv_nb(
            worker,
            buffer,
            buffer.nbytes,
            tag=0,
            cb_func=_tag_recv_handle,
            cb_args=(ep, buffer),
        )

    # Wireup before starting to transfer data
    if args.enable_am is True:
        _post_am_recv()
    else:
        _post_tag_recv(Array(bytearray(len(WireupMessage))))

    # Data receives, warmup iterations included.
    for _ in range(args.n_iter + args.n_warmup_iter):
        if args.enable_am is True:
            _post_am_recv()
            continue

        if not args.reuse_alloc:
            msg = Array(xp.zeros(args.n_bytes, dtype="u1"))
        _post_tag_recv(msg)