Пример #1
0
def _server_cancel(queue, transfer_api):
    """Server that opens an endpoint to the client and closes it right
    away, causing the client's pending receives to be canceled.
    """
    if transfer_api == "am":
        feature_flags = (ucx_api.Feature.AM, )
    else:
        feature_flags = (ucx_api.Feature.TAG, )
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)

    # The listener callback stores the endpoint here so it can be used
    # after the callback has returned.
    endpoint_holder = [None]

    def _listener_handler(conn_request):
        endpoint_holder[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

    listener = ucx_api.UCXListener(worker=worker,
                                   port=0,
                                   cb_func=_listener_handler)
    queue.put(listener.port)

    # Progress the worker until a client has connected.
    while endpoint_holder[0] is None:
        worker.progress()

    endpoint_holder[0].close()
    worker.progress()
Пример #2
0
def test_feature_flags_mismatch(feature_flag):
    """API calls must raise ValueError when the context was created
    without the feature they require."""
    ctx = ucx_api.UCXContext(feature_flags=(feature_flag, ))
    worker = ucx_api.UCXWorker(ctx)
    addr = worker.get_address()
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, addr, endpoint_error_handling=False)
    msg = Array(bytearray(10))
    if feature_flag != ucx_api.Feature.TAG:
        expected = "UCXContext must be created with `Feature.TAG`"
        with pytest.raises(ValueError, match=expected):
            ucx_api.tag_send_nb(ep, msg, msg.nbytes, 0, None)
        with pytest.raises(ValueError, match=expected):
            ucx_api.tag_recv_nb(worker, msg, msg.nbytes, 0, None)
    if feature_flag != ucx_api.Feature.STREAM:
        expected = "UCXContext must be created with `Feature.STREAM`"
        with pytest.raises(ValueError, match=expected):
            ucx_api.stream_send_nb(ep, msg, msg.nbytes, None)
        with pytest.raises(ValueError, match=expected):
            ucx_api.stream_recv_nb(ep, msg, msg.nbytes, None)
    if feature_flag != ucx_api.Feature.AM:
        expected = "UCXContext must be created with `Feature.AM`"
        with pytest.raises(ValueError, match=expected):
            ucx_api.am_send_nbx(ep, msg, msg.nbytes, None)
        with pytest.raises(ValueError, match=expected):
            ucx_api.am_recv_nb(ep, None)
Пример #3
0
def test_force_requests():
    """Issue PUTs against our own memory until one returns a request
    object, then flush and drain all outstanding operations."""
    msg_size = 1024
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)
    self_mem = ucx_api.RemoteMemory(rkey, mem.address, msg_size)

    send_msg = bytes(os.urandom(msg_size))
    attempts = 0
    req = self_mem.put_nb(send_msg, _)
    while req is None:
        attempts += 1
        req = self_mem.put_nb(send_msg, _)
        # Some transport combinations (e.g. ordinary desktop machines)
        # never exhaust their resources, so cap the attempts to make
        # sure this test still completes.
        if attempts > 10000:
            pytest.xfail("Could not generate a request")

    blocking_flush(worker)
    while worker.progress():
        pass

    while self_mem.put_nb(send_msg, _):
        pass
    blocking_flush(worker)
Пример #4
0
def _test_peer_communication_tag(queue, rank, msg_size):
    """Ring communication over TAG: rank 0 originates a random message
    that each peer forwards to its right neighbor; rank 0 verifies the
    message once it arrives back."""
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG, ))
    worker = ucx_api.UCXWorker(ctx)
    queue.put((rank, worker.get_address()))
    right_rank, right_address = queue.get()
    left_rank, left_address = queue.get()

    def _connect(address):
        # Small helper: build an endpoint to a peer's worker address.
        return ucx_api.UCXEndpoint.create_from_worker_address(
            worker,
            address,
            endpoint_error_handling=True,
        )

    right_ep = _connect(right_address)
    left_ep = _connect(left_address)
    recv_msg = bytearray(msg_size)
    if rank == 0:
        send_msg = bytes(os.urandom(msg_size))
        blocking_send(worker, right_ep, send_msg, right_rank)
        blocking_recv(worker, left_ep, recv_msg, rank)
        assert send_msg == recv_msg
    else:
        blocking_recv(worker, left_ep, recv_msg, rank)
        blocking_send(worker, right_ep, recv_msg, right_rank)
Пример #5
0
def test_ucxio_seek_bad(seek_loc, seek_flag):
    """A read after seeking past the written data yields nothing; seeking
    to (or before) the start returns the original data."""
    msg_size = 1024
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)

    uio = ucx_api.UCXIO(mem.address, msg_size, rkey)
    send_msg = bytes(os.urandom(msg_size))
    uio.write(send_msg)
    uio.seek(seek_loc, seek_flag)
    recv_msg = uio.read(len(send_msg))

    expected = b"" if seek_loc > 0 else send_msg
    assert recv_msg == expected
Пример #6
0
def _server_probe(queue, transfer_api):
    """Server that probes and receives message after client disconnected.

    Note that since it is illegal to call progress() in callback functions,
    we keep a reference to the endpoint after the listener callback has
    terminated, this way we can progress even after Python blocking calls.

    Parameters
    ----------
    queue : queue used to publish the listener port and the wireup signal
    transfer_api : "am" for active messages, anything else for TAG
    """
    feature_flags = (ucx_api.Feature.AM
                     if transfer_api == "am" else ucx_api.Feature.TAG, )
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)

    # Keep endpoint to be used from outside the listener callback
    ep = [None]

    def _listener_handler(conn_request):
        ep[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

    listener = ucx_api.UCXListener(worker=worker,
                                   port=0,
                                   cb_func=_listener_handler)
    # BUG FIX: the original line ended with a stray trailing comma
    # (`queue.put(listener.port),`), turning the statement into a
    # throwaway one-element tuple expression.
    queue.put(listener.port)

    while ep[0] is None:
        worker.progress()

    ep = ep[0]

    # Ensure wireup and inform client before it can disconnect
    if transfer_api == "am":
        wireup = blocking_am_recv(worker, ep)
    else:
        wireup = bytearray(len(WireupMessage))
        blocking_recv(worker, ep, wireup)
    queue.put("wireup completed")

    # Ensure client has disconnected -- endpoint is not alive anymore
    while ep.is_alive() is True:
        worker.progress()

    # Probe/receive message even after the remote endpoint has disconnected
    if transfer_api == "am":
        while ep.am_probe() is False:
            worker.progress()
        received = blocking_am_recv(worker, ep)
    else:
        while worker.tag_probe(0) is False:
            worker.progress()
        received = bytearray(len(DataMessage))
        blocking_recv(worker, ep, received)

    assert wireup == WireupMessage
    assert received == DataMessage
Пример #7
0
def test_pickle_ucx_address():
    """A pickled-and-restored worker address must hash and compare equal
    to the original."""
    ctx = ucx_api.UCXContext()
    worker = ucx_api.UCXWorker(ctx)
    original = worker.get_address()
    serialized = pickle.dumps(original)
    original_hash = hash(original)
    original = bytes(original)
    restored = pickle.loads(serialized)
    assert original_hash == hash(restored)
    assert bytes(original) == bytes(restored)
Пример #8
0
def _echo_server(get_queue, put_queue, msg_size, datatype):
    """Server that send received message back to the client

    Notice, since it is illegal to call progress() in call-back functions,
    we use a "chain" of call-back functions.

    Parameters
    ----------
    get_queue : queue polled to learn when the server should shut down
    put_queue : queue on which the listener port is published
    msg_size : message size in bytes (unused here; kept for interface parity)
    datatype : key into get_data() selecting allocator and memory type
    """
    data = get_data()[datatype]

    ctx = ucx_api.UCXContext(
        config_dict={"RNDV_THRESH": str(RNDV_THRESH)},
        feature_flags=(ucx_api.Feature.AM,),
    )
    worker = ucx_api.UCXWorker(ctx)
    worker.register_am_allocator(data["allocator"], data["memory_type"])

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    ep = None

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None

    def _recv_handle(recv_obj, exception, ep):
        # Echo: wrap the received object and send it straight back.
        assert exception is None
        msg = Array(recv_obj)
        ucx_api.am_send_nbx(ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,))

    def _listener_handler(conn_request):
        # BUG FIX: the original used `global ep`, which assigned a
        # module-level name and left the enclosing `ep` (intended to keep
        # the endpoint alive, per the comment above) forever None.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

        # Wireup
        ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))

        # Data
        ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    # Progress the worker until the client signals completion.
    while True:
        worker.progress()
        try:
            get_queue.get(block=False, timeout=0.1)
        except QueueIsEmpty:
            continue
        else:
            break
Пример #9
0
def _echo_client(msg_size, port, endpoint_error_handling):
    """Client that sends a random message and asserts the echoed reply
    matches it."""
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create(
        worker, "localhost", port, endpoint_error_handling=endpoint_error_handling,
    )
    payload = bytes(os.urandom(msg_size))
    echoed = bytearray(msg_size)
    blocking_send(worker, ep, payload)
    blocking_recv(worker, ep, echoed)
    assert payload == echoed
Пример #10
0
def test_rkey_unpack():
    """A packed rkey for locally-allocated memory unpacks successfully on
    a self-connected endpoint."""
    buf_size = 1024
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, buf_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=get_endpoint_error_handling_default(),
    )
    assert ep.unpack_rkey(packed_rkey) is not None
Пример #11
0
def test_listener_ip_port():
    """A listener exposes a non-empty IP string and a valid TCP port."""
    ctx = ucx_api.UCXContext()
    worker = ucx_api.UCXWorker(ctx)

    def _listener_handler(conn_request):
        pass

    listener = ucx_api.UCXListener(worker=worker,
                                   port=0,
                                   cb_func=_listener_handler)

    assert isinstance(listener.ip, str)
    assert listener.ip
    assert isinstance(listener.port, int)
    assert 0 <= listener.port <= 65535
Пример #12
0
def test_flush():
    """Flush an endpoint and, when a request object is returned, progress
    the worker until the flush finishes."""
    ctx = ucx_api.UCXContext({})
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    req = ep.flush(_)
    # BUG FIX: the original tested `if req is None` and then dereferenced
    # `req.info`, which would raise AttributeError on None. The wait loop
    # only makes sense when flush returned an in-flight request.
    if req is not None:
        info = req.info
        while info["status"] == "pending":
            worker.progress()
        assert info["status"] == "finished"
Пример #13
0
def _echo_server(get_queue, put_queue, msg_size):
    """Server that send received message back to the client

    Notice, since it is illegal to call progress() in call-back functions,
    we use a "chain" of call-back functions.

    Parameters
    ----------
    get_queue : queue polled to learn when the server should shut down
    put_queue : queue on which the listener port is published
    msg_size : size in bytes of the message to receive and echo
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    ep = None

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None

    def _recv_handle(request, exception, ep, msg):
        # Echo: once the receive completes, send the same buffer back.
        assert exception is None
        ucx_api.tag_send_nb(
            ep, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
        )

    def _listener_handler(conn_request):
        # BUG FIX: the original used `global ep`, which assigned a
        # module-level name and left the enclosing `ep` (intended to keep
        # the endpoint alive, per the comment above) forever None.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        msg = Array(bytearray(msg_size))
        ucx_api.tag_recv_nb(
            worker, msg, msg.nbytes, tag=0, cb_func=_recv_handle, cb_args=(ep, msg)
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    # Progress the worker until the client signals completion.
    while True:
        worker.progress()
        try:
            get_queue.get(block=False, timeout=0.1)
        except QueueIsEmpty:
            continue
        else:
            break
Пример #14
0
def _client_cancel(queue, transfer_api):
    """Client that connects to server and waits for messages to be received,
    because the server closes without sending anything, the messages will
    trigger cancelation.
    """
    if transfer_api == "am":
        feature_flags = (ucx_api.Feature.AM, )
    else:
        feature_flags = (ucx_api.Feature.TAG, )
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)
    port = queue.get()
    ep = ucx_api.UCXEndpoint.create(
        worker,
        get_address(),
        port,
        endpoint_error_handling=True,
    )

    ret = [None]

    if transfer_api == "am":
        ucx_api.am_recv_nb(ep, cb_func=_handler, cb_args=(ret, ))
        match_msg = ".*am_recv.*"
    else:
        msg = Array(bytearray(1))
        ucx_api.tag_recv_nb(
            worker,
            msg,
            msg.nbytes,
            tag=0,
            cb_func=_handler,
            cb_args=(ret, ),
            ep=ep,
        )
        match_msg = ".*tag_recv_nb.*"

    # Wait until the server-side close is observed locally.
    while ep.is_alive():
        worker.progress()

    canceled = worker.cancel_inflight_messages()

    # Progress until the cancelation reaches the receive callback.
    while ret[0] is None:
        worker.progress()

    assert canceled == 1
    assert isinstance(ret[0], UCXCanceled)
    assert re.match(match_msg, ret[0].args[0])
Пример #15
0
def _client(port, server_close_callback):
    """Client that either closes its endpoint (server observes the close)
    or waits for its own close callback to fire."""
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG, ))
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        ucx_api.get_address(),
        port,
        endpoint_error_handling=True,
    )
    if server_close_callback is True:
        # Close immediately; the server side is expected to notice.
        ep.close()
        worker.progress()
        return
    closed = [False]
    ep.set_close_callback(functools.partial(_close_callback, closed))
    while closed[0] is False:
        worker.progress()
Пример #16
0
def test_ucxio(msg_size):
    """Round-trip a random buffer through UCXIO: write, rewind, read back
    and compare."""
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)

    uio = ucx_api.UCXIO(mem.address, msg_size, rkey)
    payload = bytes(os.urandom(msg_size))
    uio.write(payload)
    uio.seek(0)
    assert payload == uio.read(msg_size)
Пример #17
0
def _test_peer_communication_rma(queue, rank, msg_size):
    """RMA peer test: each rank publishes its memory region, writes its own
    rank into it, reads both neighbors' regions with GET, and verifies the
    contents. A blocking tag send/recv pair at the end acts as a barrier.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.RMA,
                                            ucx_api.Feature.TAG))
    worker = ucx_api.UCXWorker(ctx)
    self_address = worker.get_address()
    mem_handle = ctx.alloc(msg_size)
    self_base = mem_handle.address
    self_prkey = mem_handle.pack_rkey()

    # Self-connected endpoint/remote-memory pair used to initialize our
    # own region before publishing it.
    self_ep, self_mem = _rma_setup(worker, self_address, self_prkey, self_base,
                                   msg_size)
    send_msg = bytes(repeat(rank, msg_size))
    # Flush when put_nbi reports the operation as not locally complete
    # (semantics of put_nbi's return value per RemoteMemory — TODO confirm).
    if not self_mem.put_nbi(send_msg):
        blocking_flush(self_ep)

    # Publish our region, then learn both neighbors' regions.
    queue.put((rank, self_address, self_prkey, self_base))
    right_rank, right_address, right_prkey, right_base = queue.get()
    left_rank, left_address, left_prkey, left_base = queue.get()

    right_ep, right_mem = _rma_setup(worker, right_address, right_prkey,
                                     right_base, msg_size)
    right_msg = bytearray(msg_size)
    right_mem.get_nbi(right_msg)

    left_ep, left_mem = _rma_setup(worker, left_address, left_prkey, left_base,
                                   msg_size)
    left_msg = bytearray(msg_size)
    left_mem.get_nbi(left_msg)

    # Non-blocking GETs above complete only after this flush.
    blocking_flush(worker)
    assert left_msg == bytes(repeat(left_rank, msg_size))
    assert right_msg == bytes(repeat(right_rank, msg_size))

    # We use the blocking tag send/recv as a barrier implementation
    recv_msg = bytearray(8)
    if rank == 0:
        send_msg = bytes(os.urandom(8))
        blocking_send(worker, right_ep, send_msg, right_rank)
        blocking_recv(worker, left_ep, recv_msg, rank)
    else:
        blocking_recv(worker, left_ep, recv_msg, rank)
        blocking_send(worker, right_ep, recv_msg, right_rank)
Пример #18
0
def test_implicit(msg_size):
    """Round-trip a buffer through implicit (nbi) PUT and GET against our
    own registered memory, flushing whenever an operation reports itself
    as not locally complete."""
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)
    self_mem = ucx_api.RemoteMemory(rkey, mem.address, msg_size)

    payload = bytes(os.urandom(msg_size))
    if not self_mem.put_nbi(payload):
        blocking_flush(ep)
    readback = bytearray(len(payload))
    if not self_mem.get_nbi(readback):
        blocking_flush(ep)
    assert payload == readback
Пример #19
0
def _client_probe(queue, transfer_api):
    """Client counterpart of the probe test: send wireup and data, then
    wait for the server's wireup acknowledgment before disconnecting."""
    if transfer_api == "am":
        feature_flags = (ucx_api.Feature.AM, )
    else:
        feature_flags = (ucx_api.Feature.TAG, )
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)
    port = queue.get()
    ep = ucx_api.UCXEndpoint.create(
        worker,
        get_address(),
        port,
        endpoint_error_handling=True,
    )

    if transfer_api == "am":
        _send = blocking_am_send
    else:
        _send = blocking_send

    _send(worker, ep, WireupMessage)
    _send(worker, ep, DataMessage)

    # Wait for wireup before disconnecting
    assert queue.get() == "wireup completed"
Пример #20
0
def test_ucxio_seek_good(seek_data):
    """After a valid seek, a 4-byte read returns the corresponding slice
    of the previously written data."""
    seek_flag, seek_dest = seek_data
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, 1024)
    msg_size = mem.length
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    rkey = ep.unpack_rkey(packed_rkey)

    uio = ucx_api.UCXIO(mem.address, msg_size, rkey)
    payload = bytes(os.urandom(msg_size))
    uio.write(payload)
    uio.seek(seek_dest, seek_flag)
    readback = uio.read(4)

    assert readback == payload[seek_dest:seek_dest + 4]
Пример #21
0
def _echo_client(msg_size, datatype, port, endpoint_error_handling):
    """AM echo client: send a wireup message and a data payload, receive
    both echoes back, and validate the results."""
    data = get_data()[datatype]

    ctx = ucx_api.UCXContext(
        config_dict={"RNDV_THRESH": str(RNDV_THRESH)},
        feature_flags=(ucx_api.Feature.AM, ),
    )
    worker = ucx_api.UCXWorker(ctx)
    worker.register_am_allocator(data["allocator"], data["memory_type"])

    ep = ucx_api.UCXEndpoint.create(
        worker,
        "localhost",
        port,
        endpoint_error_handling=endpoint_error_handling,
    )

    # The wireup message is sent to ensure endpoints are connected, otherwise
    # UCX may not perform any rendezvous transfers.
    send_wireup = bytearray(b"wireup")
    send_data = data["generator"](msg_size)

    blocking_am_send(worker, ep, send_wireup)
    blocking_am_send(worker, ep, send_data)

    recv_wireup = blocking_am_recv(worker, ep)
    recv_data = blocking_am_recv(worker, ep)

    # Cast recv_wireup to bytearray when using NumPy as a host allocator,
    # this ensures the assertion below is correct
    if datatype == "numpy":
        recv_wireup = bytearray(recv_wireup)
    assert bytearray(recv_wireup) == send_wireup

    # NOTE(review): the data assertion AND the validator only run for CUDA
    # eager messages (size below RNDV_THRESH); for every other datatype or
    # size no data validation happens at all. This looks like a possibly
    # lost `else` branch — confirm against the original test suite.
    if data["memory_type"] == "cuda" and send_data.nbytes < RNDV_THRESH:
        # Eager messages are always received on the host, if no host
        # allocator is registered UCX-Py defaults to `bytearray`.
        assert recv_data == bytearray(send_data.get())
        data["validator"](recv_data, send_data)
Пример #22
0
def _server(queue, server_close_callback):
    """Server that accepts one connection and then waits either for the
    endpoint's close callback to fire or for the listener callback to
    finish, depending on `server_close_callback`.

    Notice, since it is illegal to call progress() in call-back functions,
    we use a "chain" of call-back functions.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG, ))
    worker = ucx_api.UCXWorker(ctx)

    listener_finished = [False]
    closed = [False]

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    ep = None

    def _listener_handler(conn_request):
        # BUG FIX: the original declared `global ep` with the keep-alive
        # local commented out (`# ep = None`), leaking the endpoint into
        # module scope. `nonlocal` stores it in this function as the
        # comment above intends.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        if server_close_callback is True:
            ep.set_close_callback(functools.partial(_close_callback, closed))
        listener_finished[0] = True

    listener = ucx_api.UCXListener(worker=worker,
                                   port=0,
                                   cb_func=_listener_handler)
    queue.put(listener.port)

    if server_close_callback is True:
        while closed[0] is False:
            worker.progress()
        assert closed[0] is True
    else:
        while listener_finished[0] is False:
            worker.progress()
Пример #23
0
def client(queue, port, server_address, args):
    """Benchmark client: run `args.n_iter` round-trip transfers against
    the server and print a bandwidth report.

    Parameters
    ----------
    queue : unused here; kept for symmetry with the server entry point
    port : server's listening port
    server_address : hostname/IP of the server
    args : parsed CLI namespace controlling device, object type, AM vs
        TAG transfer, iteration counts, and reporting options
    """
    if args.client_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.client_cpu_affinity])

    import numpy as np

    # Select the array module (and GPU device/allocator) per object type.
    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.client_dev)
    else:
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.client_dev],
        )
        xp.cuda.runtime.setDevice(args.client_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(feature_flags=(
        ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
    ))
    worker = ucx_api.UCXWorker(ctx)
    register_am_allocators(args, worker)
    # Endpoint error handling requires UCX >= 1.10.
    ep = ucx_api.UCXEndpoint.create(
        worker,
        server_address,
        port,
        endpoint_error_handling=ucx_api.get_ucx_version() >= (1, 10, 0),
    )

    send_msg = xp.arange(args.n_bytes, dtype="u1")
    if args.reuse_alloc:
        recv_msg = xp.zeros(args.n_bytes, dtype="u1")

    # Wireup exchange so the endpoints are fully connected before timing.
    if args.enable_am:
        blocking_am_send(worker, ep, send_msg)
        blocking_am_recv(worker, ep)
    else:
        wireup_recv = bytearray(len(WireupMessage))
        blocking_send(worker, ep, WireupMessage)
        blocking_recv(worker, ep, wireup_recv)

    # Bookkeeping for the delayed-progress mode: number of completed and
    # currently outstanding non-blocking operations.
    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def maybe_progress():
        # Throttle: progress the worker while too many ops are in flight.
        while outstanding[0] >= args.max_outstanding:
            worker.progress()

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    if args.cuda_profile:
        xp.cuda.profiler.start()

    times = []
    for i in range(args.n_iter):
        start = clock()

        if args.enable_am:
            blocking_am_send(worker, ep, send_msg)
            blocking_am_recv(worker, ep)
        else:
            if not args.reuse_alloc:
                recv_msg = xp.zeros(args.n_bytes, dtype="u1")

            if args.delay_progress:
                maybe_progress()
                non_blocking_send(worker, ep, send_msg, op_started,
                                  op_completed)
                maybe_progress()
                non_blocking_recv(worker, ep, recv_msg, op_started,
                                  op_completed)
            else:
                blocking_send(worker, ep, send_msg)
                blocking_recv(worker, ep, recv_msg)

        stop = clock()
        times.append(stop - start)

    # In delayed-progress mode, drain the remaining in-flight operations
    # (one send and one recv per iteration).
    if args.delay_progress:
        while finished[0] != 2 * args.n_iter:
            worker.progress()

    if args.cuda_profile:
        xp.cuda.profiler.stop()

    assert len(times) == args.n_iter
    delay_progress_str = (f"True ({args.max_outstanding})"
                          if args.delay_progress is True else "False")
    print("Roundtrip benchmark")
    print("--------------------------")
    print(f"n_iter          | {args.n_iter}")
    print(f"n_bytes         | {format_bytes(args.n_bytes)}")
    print(f"object          | {args.object_type}")
    print(f"reuse alloc     | {args.reuse_alloc}")
    print(f"transfer API    | {'AM' if args.enable_am else 'TAG'}")
    print(f"delay progress  | {delay_progress_str}")
    print(f"UCX_TLS         | {ucp.get_config()['TLS']}")
    print(f"UCX_NET_DEVICES | {ucp.get_config()['NET_DEVICES']}")
    print("==========================")
    if args.object_type == "numpy":
        print("Device(s)       | CPU-only")
        s_aff = (args.server_cpu_affinity
                 if args.server_cpu_affinity >= 0 else "affinity not set")
        c_aff = (args.client_cpu_affinity
                 if args.client_cpu_affinity >= 0 else "affinity not set")
        print(f"Server CPU      | {s_aff}")
        print(f"Client CPU      | {c_aff}")
    else:
        print(f"Device(s)       | {args.server_dev}, {args.client_dev}")
    # Factor 2: each iteration moves n_bytes in both directions.
    avg = format_bytes(2 * args.n_iter * args.n_bytes / sum(times))
    med = format_bytes(2 * args.n_bytes / np.median(times))
    print(f"Average         | {avg}/s")
    print(f"Median          | {med}/s")
    if not args.no_detailed_report:
        print("--------------------------")
        print("Iterations")
        print("--------------------------")
        for i, t in enumerate(times):
            ts = format_bytes(2 * args.n_bytes / t)
            ts = (" " * (9 - len(ts))) + ts
            print("%03d         |%s/s" % (i, ts))
Пример #24
0
def client(queue, port, server_address, args):
    """Benchmark client with warmup: run `args.n_warmup_iter` untimed plus
    `args.n_iter` timed round-trips and print bandwidth/latency results.

    Parameters
    ----------
    queue : unused here; kept for symmetry with the server entry point
    port : server's listening port
    server_address : hostname/IP of the server
    args : parsed CLI namespace controlling device, object type, AM vs
        TAG transfer, iteration counts, and reporting options
    """
    if args.client_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.client_cpu_affinity])

    import numpy as np

    # Select the array module (and GPU device/allocator) per object type.
    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.client_dev)
    else:
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.client_dev],
        )
        xp.cuda.runtime.setDevice(args.client_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(
        feature_flags=(
            ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
        )
    )
    worker = ucx_api.UCXWorker(ctx)
    register_am_allocators(args, worker)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        server_address,
        port,
        endpoint_error_handling=True,
    )

    send_msg = xp.arange(args.n_bytes, dtype="u1")
    if args.reuse_alloc:
        recv_msg = xp.zeros(args.n_bytes, dtype="u1")

    # Wireup exchange so the endpoints are fully connected before timing.
    if args.enable_am:
        blocking_am_send(worker, ep, send_msg)
        blocking_am_recv(worker, ep)
    else:
        wireup_recv = bytearray(len(WireupMessage))
        blocking_send(worker, ep, WireupMessage)
        blocking_recv(worker, ep, wireup_recv)

    # Bookkeeping for the delayed-progress mode: number of completed and
    # currently outstanding non-blocking operations.
    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def maybe_progress():
        # Throttle: progress the worker while too many ops are in flight.
        while outstanding[0] >= args.max_outstanding:
            worker.progress()

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    if args.cuda_profile:
        xp.cuda.profiler.start()

    times = []
    for i in range(args.n_iter + args.n_warmup_iter):
        start = clock()

        if args.enable_am:
            blocking_am_send(worker, ep, send_msg)
            blocking_am_recv(worker, ep)
        else:
            if not args.reuse_alloc:
                recv_msg = xp.zeros(args.n_bytes, dtype="u1")

            if args.delay_progress:
                maybe_progress()
                non_blocking_send(worker, ep, send_msg, op_started, op_completed)
                maybe_progress()
                non_blocking_recv(worker, ep, recv_msg, op_started, op_completed)
            else:
                blocking_send(worker, ep, send_msg)
                blocking_recv(worker, ep, recv_msg)

        stop = clock()
        # Warmup iterations are executed but excluded from the statistics.
        if i >= args.n_warmup_iter:
            times.append(stop - start)

    # In delayed-progress mode, drain the remaining in-flight operations
    # (one send and one recv per iteration, warmup included).
    if args.delay_progress:
        while finished[0] != 2 * (args.n_iter + args.n_warmup_iter):
            worker.progress()

    if args.cuda_profile:
        xp.cuda.profiler.stop()

    assert len(times) == args.n_iter
    # Factor 2: each iteration moves n_bytes in both directions; latency
    # is reported per one-way transfer in nanoseconds.
    bw_avg = format_bytes(2 * args.n_iter * args.n_bytes / sum(times))
    bw_med = format_bytes(2 * args.n_bytes / np.median(times))
    lat_avg = int(sum(times) * 1e9 / (2 * args.n_iter))
    lat_med = int(np.median(times) * 1e9 / 2)

    delay_progress_str = (
        f"True ({args.max_outstanding})" if args.delay_progress is True else "False"
    )

    print("Roundtrip benchmark")
    print_separator(separator="=")
    print_key_value(key="Iterations", value=f"{args.n_iter}")
    print_key_value(key="Bytes", value=f"{format_bytes(args.n_bytes)}")
    print_key_value(key="Object type", value=f"{args.object_type}")
    print_key_value(key="Reuse allocation", value=f"{args.reuse_alloc}")
    print_key_value(key="Transfer API", value=f"{'AM' if args.enable_am else 'TAG'}")
    print_key_value(key="Delay progress", value=f"{delay_progress_str}")
    print_key_value(key="UCX_TLS", value=f"{ucp.get_config()['TLS']}")
    print_key_value(key="UCX_NET_DEVICES", value=f"{ucp.get_config()['NET_DEVICES']}")
    print_separator(separator="=")
    if args.object_type == "numpy":
        print_key_value(key="Device(s)", value="CPU-only")
        s_aff = (
            args.server_cpu_affinity
            if args.server_cpu_affinity >= 0
            else "affinity not set"
        )
        c_aff = (
            args.client_cpu_affinity
            if args.client_cpu_affinity >= 0
            else "affinity not set"
        )
        print_key_value(key="Server CPU", value=f"{s_aff}")
        print_key_value(key="Client CPU", value=f"{c_aff}")
    else:
        print_key_value(key="Device(s)", value=f"{args.server_dev}, {args.client_dev}")
    print_separator(separator="=")
    print_key_value("Bandwidth (average)", value=f"{bw_avg}/s")
    print_key_value("Bandwidth (median)", value=f"{bw_med}/s")
    print_key_value("Latency (average)", value=f"{lat_avg} ns")
    print_key_value("Latency (median)", value=f"{lat_med} ns")
    if not args.no_detailed_report:
        print_separator(separator="=")
        print_key_value(key="Iterations", value="Bandwidth, Latency")
        print_separator(separator="-")
        for i, t in enumerate(times):
            ts = format_bytes(2 * args.n_bytes / t)
            lat = int(t * 1e9 / 2)
            print_key_value(key=i, value=f"{ts}/s, {lat}ns")
Пример #25
0
def server(queue, args):
    """Benchmark server process: receive messages from the client and echo
    them back.

    Sets CPU affinity and the array backend (NumPy / CuPy / CuPy+RMM),
    creates a UCX context/worker with only the benchmarked feature (AM or
    TAG), starts a listener, performs a wireup exchange, then posts one
    receive per ``args.n_iter + args.n_warmup_iter`` iteration, echoing
    each received message back.  Finally drives ``worker.progress()``
    until all echoes complete.

    Parameters
    ----------
    queue
        Inter-process queue; the listener port is put on it so the client
        process can connect.  # presumably a multiprocessing.Queue — confirm at caller
    args
        Parsed benchmark arguments (affinity, devices, object type,
        transfer API, iteration counts, delay-progress options, ...).
    """
    # Pin this process to a single CPU core if requested (-1 disables).
    if args.server_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.server_cpu_affinity])

    # Select the array module used for message buffers.
    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.server_dev)
    else:
        # "rmm" (or any other value): CuPy backed by an RMM pool allocator.
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.server_dev],
        )
        xp.cuda.runtime.setDevice(args.server_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    # Enable only the feature actually benchmarked (AM or TAG).
    ctx = ucx_api.UCXContext(
        feature_flags=(
            ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
        )
    )
    worker = ucx_api.UCXWorker(ctx)

    register_am_allocators(args, worker)

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early.
    ep = None

    # Completion callbacks mutate these counters, so guard them with a
    # lock; single-element lists let the nested closures rebind the values.
    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def op_started():
        # One more receive operation is in flight.
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        # One full echo round (receive + send) has finished.
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None
        op_completed()

    def _tag_recv_handle(request, exception, ep, msg):
        # TAG path: on receive completion, echo the message back.
        assert exception is None
        req = ucx_api.tag_send_nb(
            ep, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
        )
        # A `None` request means the send completed immediately and the
        # callback will not fire, so account for the completion here.
        if req is None:
            op_completed()

    def _am_recv_handle(recv_obj, exception, ep):
        # AM path: wrap the received buffer and echo it back.
        assert exception is None
        msg = Array(recv_obj)
        ucx_api.am_send_nbx(ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,))

    def _listener_handler(conn_request, msg):
        # NOTE(review): `global ep` rebinds a *module-level* `ep`; the
        # enclosing function's local `ep = None` (and its keep-alive
        # comment above) suggest `nonlocal ep` was intended — confirm.
        # The endpoint is still kept alive via the module global either way.
        global ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

        # Wireup before starting to transfer data
        if args.enable_am is True:
            ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
        else:
            wireup = Array(bytearray(len(WireupMessage)))
            op_started()
            ucx_api.tag_recv_nb(
                worker,
                wireup,
                wireup.nbytes,
                tag=0,
                cb_func=_tag_recv_handle,
                cb_args=(ep, wireup),
            )

        # Pre-post one receive per benchmark iteration (warmup included).
        for i in range(args.n_iter + args.n_warmup_iter):
            if args.enable_am is True:
                ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
            else:
                # Without --reuse-alloc, allocate a fresh buffer per
                # iteration; otherwise `msg` is the shared buffer the
                # listener received through cb_args.
                if not args.reuse_alloc:
                    msg = Array(xp.zeros(args.n_bytes, dtype="u1"))

                op_started()
                ucx_api.tag_recv_nb(
                    worker,
                    msg,
                    msg.nbytes,
                    tag=0,
                    cb_func=_tag_recv_handle,
                    cb_args=(ep, msg),
                )

    # Allocate the single reusable receive buffer up front (TAG + reuse
    # only); the AM path allocates via the registered allocators instead.
    if not args.enable_am and args.reuse_alloc:
        msg = Array(xp.zeros(args.n_bytes, dtype="u1"))
    else:
        msg = None

    listener = ucx_api.UCXListener(
        worker=worker, port=args.port or 0, cb_func=_listener_handler, cb_args=(msg,)
    )
    # Publish the (possibly OS-assigned, when port=0) listener port so the
    # client process can connect.
    queue.put(listener.port)

    # Progress until the first operation is posted, i.e. a client connected.
    while outstanding[0] == 0:
        worker.progress()

    # +1 to account for wireup message
    if args.delay_progress:
        # Delayed-progress mode: only call progress while enough operations
        # are outstanding (or too few remain to ever reach the threshold).
        while finished[0] < args.n_iter + args.n_warmup_iter + 1 and (
            outstanding[0] >= args.max_outstanding
            or finished[0] + args.max_outstanding
            >= args.n_iter + args.n_warmup_iter + 1
        ):
            worker.progress()
    else:
        # Default: progress continuously until every echo has completed.
        while finished[0] != args.n_iter + args.n_warmup_iter + 1:
            worker.progress()