Esempio n. 1
0
def _server_cancel(queue, transfer_api):
    """Accept one client connection and close its endpoint right away.

    Closing the server-side endpoint immediately after it is created
    triggers received messages to be canceled on the client.

    ``transfer_api == "am"`` enables the active-message feature; any other
    value enables the tag feature. The listener port is published on
    ``queue`` so the client knows where to connect.
    """
    if transfer_api == "am":
        selected_feature = ucx_api.Feature.AM
    else:
        selected_feature = ucx_api.Feature.TAG
    context = ucx_api.UCXContext(feature_flags=(selected_feature,))
    worker = ucx_api.UCXWorker(context)

    # progress() must not be called from inside the listener callback, so the
    # callback only stashes the endpoint in this single slot for the loop below.
    accepted = [None]

    def _listener_handler(conn_request):
        accepted[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker, conn_request, endpoint_error_handling=True
        )

    listener = ucx_api.UCXListener(
        worker=worker, port=0, cb_func=_listener_handler
    )
    queue.put(listener.port)

    # Drive the worker until the listener callback has produced an endpoint.
    while accepted[0] is None:
        worker.progress()

    server_ep = accepted[0]
    server_ep.close()
    worker.progress()
Esempio n. 2
0
def _server_probe(queue, transfer_api):
    """Server that probes and receives message after client disconnected.

    Note that since it is illegal to call progress() in callback functions,
    we keep a reference to the endpoint after the listener callback has
    terminated, this way we can progress even after Python blocking calls.

    Parameters
    ----------
    queue:
        Used to publish the listener port and then the string
        ``"wireup completed"`` to the client.
    transfer_api:
        ``"am"`` selects active messages; any other value selects the tag API.
    """
    feature_flags = (ucx_api.Feature.AM
                     if transfer_api == "am" else ucx_api.Feature.TAG, )
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)

    # Keep endpoint to be used from outside the listener callback
    ep_holder = [None]

    def _listener_handler(conn_request):
        ep_holder[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

    listener = ucx_api.UCXListener(worker=worker,
                                   port=0,
                                   cb_func=_listener_handler)
    queue.put(listener.port)

    while ep_holder[0] is None:
        worker.progress()

    ep = ep_holder[0]

    # Ensure wireup and inform client before it can disconnect
    if transfer_api == "am":
        wireup = blocking_am_recv(worker, ep)
    else:
        wireup = bytearray(len(WireupMessage))
        blocking_recv(worker, ep, wireup)
    queue.put("wireup completed")

    # Ensure client has disconnected -- endpoint is not alive anymore
    while ep.is_alive() is True:
        worker.progress()

    # Probe/receive message even after the remote endpoint has disconnected
    if transfer_api == "am":
        while ep.am_probe() is False:
            worker.progress()
        received = blocking_am_recv(worker, ep)
    else:
        while worker.tag_probe(0) is False:
            worker.progress()
        received = bytearray(len(DataMessage))
        blocking_recv(worker, ep, received)

    # Assertions are deferred until after the client has been notified so a
    # mismatch cannot leave the client blocked waiting on the queue.
    assert wireup == WireupMessage
    assert received == DataMessage
Esempio n. 3
0
def _echo_server(get_queue, put_queue, msg_size, datatype):
    """Active-message server that sends received messages back to the client.

    Notice, since it is illegal to call progress() in call-back functions,
    we use a "chain" of call-back functions: the listener callback posts two
    AM receives (wireup and data), and each receive callback posts the echo
    send.

    Parameters
    ----------
    get_queue:
        Polled after every progress() call; any item put on it stops the
        server loop.
    put_queue:
        Queue on which the listener port is published to the client.
    msg_size:
        Not used by this function.
    datatype:
        Key into ``get_data()`` selecting the AM allocator and memory type
        registered on the worker.
    """
    data = get_data()[datatype]

    ctx = ucx_api.UCXContext(
        config_dict={"RNDV_THRESH": str(RNDV_THRESH)},
        feature_flags=(ucx_api.Feature.AM,),
    )
    worker = ucx_api.UCXWorker(ctx)
    worker.register_am_allocator(data["allocator"], data["memory_type"])

    # A reference to listener's endpoint is stored in this scope (assigned
    # via ``nonlocal`` in the listener callback) to prevent it from going
    # out of scope too early.
    ep = None

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None

    def _recv_handle(recv_obj, exception, ep):
        assert exception is None
        msg = Array(recv_obj)
        ucx_api.am_send_nbx(ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,))

    def _listener_handler(conn_request):
        # ``nonlocal`` (not ``global``) so the endpoint is kept by this
        # function's ``ep`` rather than leaking into module scope.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

        # Wireup
        ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))

        # Data
        ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    # Progress until the client signals shutdown through get_queue. A
    # non-blocking get ignores any timeout, so none is passed.
    while True:
        worker.progress()
        try:
            get_queue.get(block=False)
        except QueueIsEmpty:
            continue
        break
Esempio n. 4
0
def test_listener_ip_port():
    """A listener created with ``port=0`` reports a non-empty IP and a valid port."""
    context = ucx_api.UCXContext()
    worker = ucx_api.UCXWorker(context)

    def _listener_handler(conn_request):
        pass  # nothing connects in this test

    listener = ucx_api.UCXListener(
        worker=worker, port=0, cb_func=_listener_handler
    )

    ip = listener.ip
    assert isinstance(ip, str) and ip

    port = listener.port
    assert isinstance(port, int) and 0 <= port <= 65535
Esempio n. 5
0
def _echo_server(get_queue, put_queue, msg_size):
    """Tag-API server that sends a received message back to the client.

    Notice, since it is illegal to call progress() in call-back functions,
    we use a "chain" of call-back functions: the listener callback posts one
    tag receive into a ``msg_size``-byte buffer, and the receive callback
    posts the echo send of that buffer.

    Parameters
    ----------
    get_queue:
        Polled after every progress() call; any item put on it stops the
        server loop.
    put_queue:
        Queue on which the listener port is published to the client.
    msg_size:
        Size in bytes of the receive buffer.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)

    # A reference to listener's endpoint is stored in this scope (assigned
    # via ``nonlocal`` in the listener callback) to prevent it from going
    # out of scope too early.
    ep = None

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None

    def _recv_handle(request, exception, ep, msg):
        assert exception is None
        ucx_api.tag_send_nb(
            ep, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
        )

    def _listener_handler(conn_request):
        # ``nonlocal`` (not ``global``) so the endpoint is kept by this
        # function's ``ep`` rather than leaking into module scope.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        msg = Array(bytearray(msg_size))
        ucx_api.tag_recv_nb(
            worker, msg, msg.nbytes, tag=0, cb_func=_recv_handle, cb_args=(ep, msg)
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    # Progress until the client signals shutdown through get_queue. A
    # non-blocking get ignores any timeout, so none is passed.
    while True:
        worker.progress()
        try:
            get_queue.get(block=False)
        except QueueIsEmpty:
            continue
        break
Esempio n. 6
0
def _server(queue, server_close_callback):
    """Server that accepts a single client connection.

    Notice, since it is illegal to call progress() in call-back functions,
    the listener callback only creates the endpoint (and, if requested,
    registers its close callback); all progressing happens in the loops
    at the end of this function.

    Parameters
    ----------
    queue:
        Queue on which the listener port is published to the client.
    server_close_callback:
        If ``True``, register a close callback on the server endpoint and
        progress until ``closed[0]`` becomes True (assumed to be set by
        ``_close_callback``). Otherwise progress only until the connection
        has been accepted.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG, ))
    worker = ucx_api.UCXWorker(ctx)

    listener_finished = [False]
    closed = [False]

    # A reference to listener's endpoint is stored in this scope (assigned
    # via ``nonlocal`` in the listener callback) to prevent it from going
    # out of scope too early.
    ep = None

    def _listener_handler(conn_request):
        # ``nonlocal`` (not ``global``) so the endpoint is kept by this
        # function's ``ep`` rather than leaking into module scope.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        if server_close_callback is True:
            ep.set_close_callback(functools.partial(_close_callback, closed))
        listener_finished[0] = True

    listener = ucx_api.UCXListener(worker=worker,
                                   port=0,
                                   cb_func=_listener_handler)
    queue.put(listener.port)

    if server_close_callback is True:
        while closed[0] is False:
            worker.progress()
        assert closed[0] is True
    else:
        while listener_finished[0] is False:
            worker.progress()
Esempio n. 7
0
def server(queue, args):
    """Benchmark server that sends every received message back to the client.

    Notice, since it is illegal to call progress() in call-back functions,
    we use a "chain" of call-back functions: the listener callback posts all
    receives up front, and each receive callback posts the echo send.

    Progress is tracked with two lock-guarded counters. ``outstanding``
    counts posted receives whose echo send has not completed yet, and
    ``finished`` counts completed echo sends. The server progresses until
    ``args.n_iter + args.n_warmup_iter + 1`` echoes have finished; the +1
    is the wireup message. This works for both the tag and the
    active-message (AM) paths.

    Parameters
    ----------
    queue:
        Queue on which the listener port is published to the client.
    args:
        Parsed benchmark options. Read here: ``server_cpu_affinity``,
        ``object_type``, ``server_dev``, ``rmm_init_pool_size``,
        ``enable_am``, ``reuse_alloc``, ``n_bytes``, ``n_iter``,
        ``n_warmup_iter``, ``port``, ``delay_progress`` and
        ``max_outstanding``. ``args`` is also passed to
        ``register_am_allocators``.
    """
    if args.server_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.server_cpu_affinity])

    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.server_dev)
    else:
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.server_dev],
        )
        xp.cuda.runtime.setDevice(args.server_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(
        feature_flags=(
            ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
        )
    )
    worker = ucx_api.UCXWorker(ctx)

    register_am_allocators(args, worker)

    # A reference to listener's endpoint is stored in this scope (assigned
    # via ``nonlocal`` in the listener callback) to prevent it from going
    # out of scope too early.
    ep = None

    # Total echoes expected: data iterations plus one wireup message.
    n_total = args.n_iter + args.n_warmup_iter + 1

    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None
        op_completed()

    def _tag_recv_handle(request, exception, ep, msg):
        assert exception is None
        req = ucx_api.tag_send_nb(
            ep, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
        )
        # A send that completes immediately returns None and _send_handle
        # is not expected to run, so count the completion here.
        if req is None:
            op_completed()

    def _am_recv_handle(recv_obj, exception, ep):
        assert exception is None
        msg = Array(recv_obj)
        req = ucx_api.am_send_nbx(
            ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,)
        )
        # Same immediate-completion accounting as the tag path; without it
        # an immediately-completed AM send is never counted as finished.
        if req is None:
            op_completed()

    def _post_am_recv(ep):
        # Count the op before posting: the receive may complete at once.
        op_started()
        ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))

    def _post_tag_recv(ep, buf):
        op_started()
        ucx_api.tag_recv_nb(
            worker,
            buf,
            buf.nbytes,
            tag=0,
            cb_func=_tag_recv_handle,
            cb_args=(ep, buf),
        )

    def _listener_handler(conn_request, msg):
        # ``nonlocal`` (not ``global``) so the endpoint is kept by this
        # function's ``ep`` rather than leaking into module scope.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

        # Wireup before starting to transfer data
        if args.enable_am is True:
            _post_am_recv(ep)
        else:
            _post_tag_recv(ep, Array(bytearray(len(WireupMessage))))

        for _ in range(args.n_iter + args.n_warmup_iter):
            if args.enable_am is True:
                _post_am_recv(ep)
            else:
                # Without reuse_alloc every receive gets a fresh buffer;
                # otherwise the shared buffer passed via cb_args is reused.
                if not args.reuse_alloc:
                    msg = Array(xp.zeros(args.n_bytes, dtype="u1"))
                _post_tag_recv(ep, msg)

    if not args.enable_am and args.reuse_alloc:
        msg = Array(xp.zeros(args.n_bytes, dtype="u1"))
    else:
        msg = None

    listener = ucx_api.UCXListener(
        worker=worker, port=args.port or 0, cb_func=_listener_handler, cb_args=(msg,)
    )
    queue.put(listener.port)

    # Wait until the listener callback has posted the receives.
    while outstanding[0] == 0:
        worker.progress()

    if args.delay_progress:
        while finished[0] < n_total and (
            outstanding[0] >= args.max_outstanding
            or finished[0] + args.max_outstanding >= n_total
        ):
            worker.progress()
    else:
        while finished[0] != n_total:
            worker.progress()