def _server_cancel(queue, transfer_api):
    """Accept one client connection and close the endpoint right away.

    Closing the server-side endpoint immediately is what triggers the
    messages the client is receiving to be canceled on the client side.

    Parameters
    ----------
    queue
        Queue used to publish the listener port to the client.
    transfer_api
        ``"am"`` selects the active-message feature; any other value selects
        the tag feature.
    """
    if transfer_api == "am":
        selected_feature = ucx_api.Feature.AM
    else:
        selected_feature = ucx_api.Feature.TAG
    context = ucx_api.UCXContext(feature_flags=(selected_feature,))
    worker = ucx_api.UCXWorker(context)

    # Mutable holder so the listener callback can hand the endpoint back to
    # this function, where progress() may legally be called.
    endpoint_holder = [None]

    def _on_connection(conn_request):
        endpoint_holder[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_on_connection)
    queue.put(listener.port)

    # Drive the worker until a client has connected.
    while endpoint_holder[0] is None:
        worker.progress()

    endpoint_holder[0].close()
    worker.progress()
def _server_probe(queue, transfer_api):
    """Server that probes and receives a message after the client disconnected.

    Since it is illegal to call progress() in callback functions, a reference
    to the endpoint is kept after the listener callback has returned; this
    way the worker can still be progressed around Python blocking calls.

    Parameters
    ----------
    queue
        Queue used to publish the listener port and then to send
        ``"wireup completed"`` to the client.
    transfer_api
        ``"am"`` uses active messages; any other value uses tag messages.
    """
    feature_flags = (
        ucx_api.Feature.AM if transfer_api == "am" else ucx_api.Feature.TAG,
    )
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)

    # Holder for the endpoint, filled by the listener callback and read from
    # outside it. It is kept distinct from `ep` below so the callback never
    # writes into a name that has been rebound to the endpoint itself.
    ep_holder = [None]

    def _listener_handler(conn_request):
        ep_holder[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    queue.put(listener.port)

    while ep_holder[0] is None:
        worker.progress()
    ep = ep_holder[0]

    # Ensure wireup and inform client before it can disconnect
    if transfer_api == "am":
        wireup = blocking_am_recv(worker, ep)
    else:
        wireup = bytearray(len(WireupMessage))
        blocking_recv(worker, ep, wireup)
    queue.put("wireup completed")

    # Ensure client has disconnected -- endpoint is not alive anymore
    while ep.is_alive() is True:
        worker.progress()

    # Probe/receive message even after the remote endpoint has disconnected
    if transfer_api == "am":
        while ep.am_probe() is False:
            worker.progress()
        received = blocking_am_recv(worker, ep)
    else:
        while worker.tag_probe(0) is False:
            worker.progress()
        received = bytearray(len(DataMessage))
        blocking_recv(worker, ep, received)

    assert wireup == WireupMessage
    assert received == DataMessage
def _echo_server(get_queue, put_queue, msg_size, datatype):
    """Server that sends received active messages back to the client.

    Since it is illegal to call progress() in callback functions, a "chain"
    of callback functions is used: each receive callback posts the send.

    Parameters
    ----------
    get_queue
        Queue polled for the stop signal; any item makes the server return.
    put_queue
        Queue used to publish the listener port to the client.
    msg_size
        Not used by this active-message variant; kept for interface
        compatibility with callers.
    datatype
        Key into ``get_data()`` selecting the allocator and memory type
        registered for incoming active messages.
    """
    data = get_data()[datatype]

    ctx = ucx_api.UCXContext(
        config_dict={"RNDV_THRESH": str(RNDV_THRESH)},
        feature_flags=(ucx_api.Feature.AM,),
    )
    worker = ucx_api.UCXWorker(ctx)
    worker.register_am_allocator(data["allocator"], data["memory_type"])

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    ep = None

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None

    def _recv_handle(recv_obj, exception, ep):
        assert exception is None
        msg = Array(recv_obj)
        ucx_api.am_send_nbx(ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,))

    def _listener_handler(conn_request):
        # `nonlocal` stores the endpoint in the `ep` declared above; `global`
        # would create a module-level name and leave the local one unused.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        # Wireup
        ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))
        # Data
        ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    # Progress the worker until the stop signal arrives. The poll is
    # non-blocking (a timeout has no effect when block=False).
    while True:
        worker.progress()
        try:
            get_queue.get_nowait()
        except QueueIsEmpty:
            continue
        break
def test_listener_ip_port():
    """A listener created with port=0 reports a non-empty IP and a valid port."""
    ctx = ucx_api.UCXContext()
    worker = ucx_api.UCXWorker(ctx)

    def _ignore_connection(conn_request):
        # No client connects in this test; the listener just needs a callback.
        pass

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_ignore_connection)

    ip = listener.ip
    assert isinstance(ip, str) and ip

    port = listener.port
    assert isinstance(port, int) and 0 <= port <= 65535
def _echo_server(get_queue, put_queue, msg_size):
    """Server that sends the received tag message back to the client.

    Since it is illegal to call progress() in callback functions, a "chain"
    of callback functions is used: the receive callback posts the send.

    Parameters
    ----------
    get_queue
        Queue polled for the stop signal; any item makes the server return.
    put_queue
        Queue used to publish the listener port to the client.
    msg_size
        Size in bytes of the buffer posted for the tag receive.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    ep = None

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None

    def _recv_handle(request, exception, ep, msg):
        assert exception is None
        ucx_api.tag_send_nb(
            ep, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
        )

    def _listener_handler(conn_request):
        # `nonlocal` stores the endpoint in the `ep` declared above; `global`
        # would create a module-level name and leave the local one unused.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        msg = Array(bytearray(msg_size))
        ucx_api.tag_recv_nb(
            worker, msg, msg.nbytes, tag=0, cb_func=_recv_handle, cb_args=(ep, msg)
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    # Progress the worker until the stop signal arrives. The poll is
    # non-blocking (a timeout has no effect when block=False).
    while True:
        worker.progress()
        try:
            get_queue.get_nowait()
        except QueueIsEmpty:
            continue
        break
def _server(queue, server_close_callback):
    """Server that accepts a single client connection.

    The listener callback only creates the endpoint, because it is illegal
    to call progress() in callback functions. This function then progresses
    the worker until either the endpoint's close callback has flagged it as
    closed (when ``server_close_callback is True``) or the listener callback
    has run.

    Parameters
    ----------
    queue
        Queue used to publish the listener port to the client.
    server_close_callback
        When ``True``, register ``_close_callback`` (defined elsewhere; it is
        given the ``closed`` flag list, presumably to set ``closed[0]``) on the
        endpoint and progress until ``closed[0]`` is no longer ``False``.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)

    listener_finished = [False]
    closed = [False]

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    ep = None

    def _listener_handler(conn_request):
        # `nonlocal` stores the endpoint in the `ep` declared above; `global`
        # would create a module-level name instead.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        if server_close_callback is True:
            ep.set_close_callback(functools.partial(_close_callback, closed))
        listener_finished[0] = True

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    queue.put(listener.port)

    if server_close_callback is True:
        while closed[0] is False:
            worker.progress()
    else:
        while listener_finished[0] is False:
            worker.progress()
def server(queue, args):
    """Server that echoes every received message back to the client.

    The listener callback posts all receives up front: one wireup message
    plus ``args.n_warmup_iter + args.n_iter`` data messages. Each receive
    callback sends the message back. Since it is illegal to call progress()
    in callback functions, this function progresses the worker until every
    echo has completed.

    Parameters
    ----------
    queue
        Queue used to publish the listener port to the client.
    args
        Parsed options. Fields read here: ``server_cpu_affinity``,
        ``object_type``, ``server_dev``, ``rmm_init_pool_size``,
        ``enable_am``, ``port``, ``n_iter``, ``n_warmup_iter``,
        ``reuse_alloc``, ``n_bytes``, ``delay_progress`` and
        ``max_outstanding``.
    """
    if args.server_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.server_cpu_affinity])

    # Select the array library used for receive buffers.
    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.server_dev)
    else:
        # Any other object type: CuPy allocating from an RMM pool.
        import cupy as xp
        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.server_dev],
        )
        xp.cuda.runtime.setDevice(args.server_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(
        feature_flags=(
            ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
        )
    )
    worker = ucx_api.UCXWorker(ctx)

    register_am_allocators(args, worker)

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    ep = None

    # Each posted receive is echoed exactly once; +1 for the wireup message.
    n_expected = args.n_iter + args.n_warmup_iter + 1

    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None
        op_completed()

    def _tag_recv_handle(request, exception, ep, msg):
        assert exception is None
        req = ucx_api.tag_send_nb(
            ep, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
        )
        # A send that completed immediately returns None without invoking
        # `_send_handle`, so account for it here.
        if req is None:
            op_completed()

    def _am_recv_handle(recv_obj, exception, ep):
        assert exception is None
        msg = Array(recv_obj)
        req = ucx_api.am_send_nbx(
            ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,)
        )
        # Same immediate-completion accounting as the tag path; otherwise
        # `finished` could never reach `n_expected`.
        if req is None:
            op_completed()

    def _listener_handler(conn_request, msg):
        # `nonlocal` stores the endpoint in the `ep` declared above; `global`
        # would create a module-level name instead.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

        # Every receive is counted via op_started() *before* it is posted,
        # for both transfer APIs, so `outstanding` never goes negative.

        # Wireup before starting to transfer data
        if args.enable_am is True:
            op_started()
            ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
        else:
            wireup = Array(bytearray(len(WireupMessage)))
            op_started()
            ucx_api.tag_recv_nb(
                worker,
                wireup,
                wireup.nbytes,
                tag=0,
                cb_func=_tag_recv_handle,
                cb_args=(ep, wireup),
            )

        for _ in range(args.n_iter + args.n_warmup_iter):
            if args.enable_am is True:
                op_started()
                ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
            else:
                if not args.reuse_alloc:
                    msg = Array(xp.zeros(args.n_bytes, dtype="u1"))

                op_started()
                ucx_api.tag_recv_nb(
                    worker,
                    msg,
                    msg.nbytes,
                    tag=0,
                    cb_func=_tag_recv_handle,
                    cb_args=(ep, msg),
                )

    # With reuse_alloc (tag API only), one buffer is shared by all receives.
    if not args.enable_am and args.reuse_alloc:
        msg = Array(xp.zeros(args.n_bytes, dtype="u1"))
    else:
        msg = None

    listener = ucx_api.UCXListener(
        worker=worker, port=args.port or 0, cb_func=_listener_handler, cb_args=(msg,)
    )
    queue.put(listener.port)

    # Wait for the listener callback to post its receives.
    while outstanding[0] == 0:
        worker.progress()

    if args.delay_progress:
        # Keep progressing while not everything has finished and either at
        # least `max_outstanding` operations are in flight or the remaining
        # operations fit within `max_outstanding`.
        while finished[0] < n_expected and (
            outstanding[0] >= args.max_outstanding
            or finished[0] + args.max_outstanding >= n_expected
        ):
            worker.progress()
    else:
        # `<` rather than `!=` so an over-count can never spin forever.
        while finished[0] < n_expected:
            worker.progress()