Exemple #1
0
def _server_cancel(queue, transfer_api):
    """Accept a single connection and close it straight away.

    The listener port is published on ``queue``. As soon as the listener
    callback has produced an endpoint, that endpoint is closed and the worker
    is progressed once more; the intent is that receives pending on the
    client get canceled as a result.
    """
    if transfer_api == "am":
        feature = ucx_api.Feature.AM
    else:
        feature = ucx_api.Feature.TAG
    ctx = ucx_api.UCXContext(feature_flags=(feature, ))
    worker = ucx_api.UCXWorker(ctx)

    # One-slot holder so the endpoint created inside the callback remains
    # reachable after the callback has returned
    accepted = [None]

    def _listener_handler(conn_request):
        accepted[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker, conn_request, endpoint_error_handling=True)

    listener = ucx_api.UCXListener(
        worker=worker,
        port=0,
        cb_func=_listener_handler,
    )
    queue.put(listener.port)

    # Spin until the listener callback has stored an endpoint
    while accepted[0] is None:
        worker.progress()

    accepted[0].close()
    worker.progress()
def _test_peer_communication_tag(queue, rank, msg_size):
    """Pass a tagged message between this peer and its two neighbours.

    The peer publishes ``(rank, worker address)`` on ``queue`` and then reads
    two entries back: first the right neighbour, then the left one. Rank 0
    sends random bytes to the right and asserts that the same bytes come back
    from the left; every other rank receives from the left and forwards the
    buffer to the right.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG, ))
    worker = ucx_api.UCXWorker(ctx)
    queue.put((rank, worker.get_address()))
    right_rank, right_address = queue.get()
    left_rank, left_address = queue.get()

    def _connect(address):
        return ucx_api.UCXEndpoint.create_from_worker_address(
            worker, address, endpoint_error_handling=True)

    right_ep = _connect(right_address)
    left_ep = _connect(left_address)

    inbox = bytearray(msg_size)
    if rank != 0:
        # Forwarding peer: relay whatever arrives from the left
        blocking_recv(worker, left_ep, inbox, rank)
        blocking_send(worker, right_ep, inbox, right_rank)
        return

    # Rank 0 originates the message and checks what comes back
    payload = bytes(os.urandom(msg_size))
    blocking_send(worker, right_ep, payload, right_rank)
    blocking_recv(worker, left_ep, inbox, rank)
    assert payload == inbox
Exemple #3
0
def test_force_requests():
    """Repeat ``put_nb`` on self-mapped memory until a request is returned.

    If no request object is produced after more than 10000 extra attempts
    the test is marked as an expected failure. Afterwards the worker is
    flushed and drained, and puts are issued until one returns a falsy value.

    ``_`` is the callback passed to ``put_nb``; it is defined elsewhere in
    this module (presumably a no-op -- confirm).
    """
    msg_size = 1024
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=True)
    rkey = ep.unpack_rkey(packed_rkey)
    self_mem = ucx_api.RemoteMemory(rkey, mem.address, msg_size)

    payload = bytes(os.urandom(msg_size))
    attempts = 0
    req = self_mem.put_nb(payload, _)
    while req is None:
        attempts += 1
        req = self_mem.put_nb(payload, _)
        # Some transport combinations (e.g. an ordinary desktop PC) never
        # get exhausted, so bail out instead of looping forever
        if attempts > 10000:
            pytest.xfail("Could not generate a request")

    blocking_flush(worker)
    # Drain the worker until progress() reports nothing left to do
    while worker.progress():
        pass

    while self_mem.put_nb(payload, _):
        pass
    blocking_flush(worker)
Exemple #4
0
def test_feature_flags_mismatch(feature_flag):
    """Check that each transfer API rejects a context lacking its feature.

    The context is created with only ``feature_flag`` enabled. For every API
    family (TAG, STREAM, AM) other than the enabled one, both the send and
    the receive call must raise ``ValueError`` naming the missing feature.
    """
    ctx = ucx_api.UCXContext(feature_flags=(feature_flag, ))
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=False,
    )
    msg = Array(bytearray(10))

    tag_error = "UCXContext must be created with `Feature.TAG`"
    stream_error = "UCXContext must be created with `Feature.STREAM`"
    am_error = "UCXContext must be created with `Feature.AM`"

    if feature_flag != ucx_api.Feature.TAG:
        with pytest.raises(ValueError, match=tag_error):
            ucx_api.tag_send_nb(ep, msg, msg.nbytes, 0, None)
        with pytest.raises(ValueError, match=tag_error):
            ucx_api.tag_recv_nb(worker, msg, msg.nbytes, 0, None)

    if feature_flag != ucx_api.Feature.STREAM:
        with pytest.raises(ValueError, match=stream_error):
            ucx_api.stream_send_nb(ep, msg, msg.nbytes, None)
        with pytest.raises(ValueError, match=stream_error):
            ucx_api.stream_recv_nb(ep, msg, msg.nbytes, None)

    if feature_flag != ucx_api.Feature.AM:
        with pytest.raises(ValueError, match=am_error):
            ucx_api.am_send_nbx(ep, msg, msg.nbytes, None)
        with pytest.raises(ValueError, match=am_error):
            ucx_api.am_recv_nb(ep, None)
Exemple #5
0
def test_ucxio_seek_bad(seek_loc, seek_flag):
    """Read through ``UCXIO`` after a seek supplied by the parametrization.

    A full buffer of random bytes is written, the stream is repositioned
    with ``seek(seek_loc, seek_flag)`` and a full-length read is attempted.
    A positive ``seek_loc`` is expected to yield no data; any other value is
    expected to yield the whole message.
    """
    msg_size = 1024
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=True)
    rkey = ep.unpack_rkey(packed_rkey)

    stream = ucx_api.UCXIO(mem.address, msg_size, rkey)
    payload = bytes(os.urandom(msg_size))
    stream.write(payload)
    stream.seek(seek_loc, seek_flag)
    received = stream.read(len(payload))

    expected = b"" if seek_loc > 0 else payload
    assert received == expected
Exemple #6
0
def _server_probe(queue, transfer_api):
    """Probe for and receive a message once the client has disconnected.

    Calling progress() from inside a callback is illegal, so the listener
    callback only stores the endpoint; all progressing and blocking calls
    happen afterwards from the body of this function.

    The listener port and, later, the string ``"wireup completed"`` are put
    on ``queue``.
    """
    use_am = transfer_api == "am"
    feature = ucx_api.Feature.AM if use_am else ucx_api.Feature.TAG
    ctx = ucx_api.UCXContext(feature_flags=(feature, ))
    worker = ucx_api.UCXWorker(ctx)

    # One-slot holder so the endpoint outlives the listener callback
    accepted = [None]

    def _listener_handler(conn_request):
        accepted[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker, conn_request, endpoint_error_handling=True)

    listener = ucx_api.UCXListener(
        worker=worker,
        port=0,
        cb_func=_listener_handler,
    )
    queue.put(listener.port)

    while accepted[0] is None:
        worker.progress()

    ep = accepted[0]

    # Complete wireup and tell the client before it is allowed to disconnect
    if use_am:
        wireup = blocking_am_recv(worker, ep)
    else:
        wireup = bytearray(len(WireupMessage))
        blocking_recv(worker, ep, wireup)
    queue.put("wireup completed")

    # Wait until the endpoint reports it is no longer alive, i.e. the
    # client has disconnected
    while ep.is_alive() is True:
        worker.progress()

    # The data message must still be probeable and receivable at this point
    if use_am:
        while ep.am_probe() is False:
            worker.progress()
        received = blocking_am_recv(worker, ep)
    else:
        while worker.tag_probe(0) is False:
            worker.progress()
        received = bytearray(len(DataMessage))
        blocking_recv(worker, ep, received)

    assert wireup == WireupMessage
    assert received == DataMessage
def test_pickle_ucx_address():
    """Round-trip a worker address through pickle.

    Both the hash and the raw bytes of the unpickled address must match
    those of the original.
    """
    ctx = ucx_api.UCXContext()
    worker = ucx_api.UCXWorker(ctx)
    address = worker.get_address()

    pickled = pickle.dumps(address)
    expected_hash = hash(address)
    expected_bytes = bytes(address)

    restored = pickle.loads(pickled)
    assert expected_hash == hash(restored)
    assert expected_bytes == bytes(restored)
Exemple #8
0
def _echo_server(get_queue, put_queue, msg_size, datatype):
    """Active-message server that sends each received message back.

    The listener port is published on ``put_queue``. The worker is then
    progressed until any item becomes available on ``get_queue``, which acts
    as the shutdown signal. ``msg_size`` is not used by this function.

    Notice, since it is illegal to call progress() in call-back functions,
    we use a "chain" of call-back functions: the listener callback posts the
    receives and each receive callback posts the matching send.
    """
    data = get_data()[datatype]

    ctx = ucx_api.UCXContext(
        config_dict={"RNDV_THRESH": str(RNDV_THRESH)},
        feature_flags=(ucx_api.Feature.AM,),
    )
    worker = ucx_api.UCXWorker(ctx)
    worker.register_am_allocator(data["allocator"], data["memory_type"])

    # One-slot holder for the listener's endpoint, so it does not go out of
    # scope when the listener callback returns. A list is used (rather than
    # `global`) so the callback can rebind the slot without leaking a
    # module-level name.
    ep = [None]

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None

    def _recv_handle(recv_obj, exception, endpoint):
        assert exception is None
        msg = Array(recv_obj)
        ucx_api.am_send_nbx(
            endpoint, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,)
        )

    def _listener_handler(conn_request):
        ep[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

        # Wireup
        ucx_api.am_recv_nb(ep[0], cb_func=_recv_handle, cb_args=(ep[0],))

        # Data
        ucx_api.am_recv_nb(ep[0], cb_func=_recv_handle, cb_args=(ep[0],))

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    # Keep progressing until the shutdown item shows up on `get_queue`
    while True:
        worker.progress()
        try:
            get_queue.get(block=False, timeout=0.1)
        except QueueIsEmpty:
            continue
        else:
            break
Exemple #9
0
def _echo_client(msg_size, port, endpoint_error_handling):
    """Send random bytes to the tag echo server and check the reply.

    Connects to ``localhost:port``, sends ``msg_size`` random bytes, receives
    the same number of bytes and asserts both buffers are equal.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        "localhost",
        port,
        endpoint_error_handling=endpoint_error_handling,
    )

    outgoing = bytes(os.urandom(msg_size))
    incoming = bytearray(msg_size)
    blocking_send(worker, ep, outgoing)
    blocking_recv(worker, ep, incoming)
    assert outgoing == incoming
Exemple #10
0
def test_rkey_unpack():
    """Pack an rkey for allocated memory and unpack it on a self endpoint."""
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, 1024)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    error_handling = get_endpoint_error_handling_default()
    # The endpoint targets this worker's own address
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=error_handling)
    assert ep.unpack_rkey(packed_rkey) is not None
Exemple #11
0
def test_listener_ip_port():
    """A listener exposes a non-empty string IP and a port within 0..65535."""
    ctx = ucx_api.UCXContext()
    worker = ucx_api.UCXWorker(ctx)

    def _listener_handler(conn_request):
        # No connection is expected; the callback only has to exist
        pass

    listener = ucx_api.UCXListener(
        worker=worker,
        port=0,
        cb_func=_listener_handler,
    )

    assert isinstance(listener.ip, str)
    assert listener.ip
    assert isinstance(listener.port, int)
    assert listener.port >= 0
    assert listener.port <= 65535
Exemple #12
0
def test_flush():
    """Flush a self-connected endpoint and wait for the request to finish.

    ``ep.flush`` may return ``None`` instead of a request object (presumably
    when the flush completed immediately -- confirm against the API); in
    that case there is nothing to wait for. Otherwise the worker is
    progressed until the request leaves the "pending" state, which must then
    be "finished".

    ``_`` is the callback passed to ``flush``; it is defined elsewhere in
    this module.
    """
    ctx = ucx_api.UCXContext({})
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker,
        worker.get_address(),
        endpoint_error_handling=True,
    )
    req = ep.flush(_)
    # Only a real request object carries `info`; dereferencing it when
    # `req` is None would raise AttributeError.
    if req is not None:
        info = req.info
        while info["status"] == "pending":
            worker.progress()
        assert info["status"] == "finished"
Exemple #13
0
def test_get_config():
    """With ``UCX_TLS`` unset, the context config reports ``TLS == "all"``.

    Any user-defined ``UCX_TLS`` is removed for the duration of the test and
    restored afterwards, even when the test body fails.
    """
    # Cache user-defined UCX_TLS and unset it to test default value
    tls = os.environ.pop("UCX_TLS", None)
    try:
        ctx = ucx_api.UCXContext()
        config = ctx.get_config()
        assert isinstance(config, dict)
        assert config["TLS"] == "all"
    finally:
        # Restore user-defined UCX_TLS so a failure here cannot leak the
        # modified environment into later tests
        if tls is not None:
            os.environ["UCX_TLS"] = tls
Exemple #14
0
def _echo_server(get_queue, put_queue, msg_size):
    """Tag server that sends the received message back to the client.

    The listener port is published on ``put_queue``. The worker is then
    progressed until any item becomes available on ``get_queue``, which acts
    as the shutdown signal.

    Notice, since it is illegal to call progress() in call-back functions,
    we use a "chain" of call-back functions: the listener callback posts the
    receive and the receive callback posts the send.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)

    # One-slot holder for the listener's endpoint, so it does not go out of
    # scope when the listener callback returns. A list is used (rather than
    # `global`) so the callback can rebind the slot without leaking a
    # module-level name.
    ep = [None]

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None

    def _recv_handle(request, exception, endpoint, msg):
        assert exception is None
        ucx_api.tag_send_nb(
            endpoint, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
        )

    def _listener_handler(conn_request):
        ep[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        msg = Array(bytearray(msg_size))
        ucx_api.tag_recv_nb(
            worker, msg, msg.nbytes, tag=0, cb_func=_recv_handle, cb_args=(ep[0], msg)
        )

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    put_queue.put(listener.port)

    # Keep progressing until the shutdown item shows up on `get_queue`
    while True:
        worker.progress()
        try:
            get_queue.get(block=False, timeout=0.1)
        except QueueIsEmpty:
            continue
        else:
            break
Exemple #15
0
def _client(port, server_close_callback):
    """Connect to the server and take part in the close-callback exchange.

    When ``server_close_callback`` is True the client closes its endpoint
    and progresses the worker once. Otherwise it registers
    ``_close_callback`` (bound to a one-slot ``closed`` flag) on the endpoint
    and progresses until that flag is no longer False.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG, ))
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create(
        worker, ucx_api.get_address(), port, endpoint_error_handling=True)

    if server_close_callback is not True:
        # This side observes the close
        closed = [False]
        on_close = functools.partial(_close_callback, closed)
        ep.set_close_callback(on_close)
        while closed[0] is False:
            worker.progress()
        return

    # This side initiates the close
    ep.close()
    worker.progress()
Exemple #16
0
def _client_cancel(queue, transfer_api):
    """Post a receive that never completes and check that it is canceled.

    The client connects to the port read from ``queue`` and posts a single
    receive. Once the endpoint is no longer alive, in-flight messages are
    canceled on the worker. Exactly one cancelation is expected, and the
    value stored in ``result[0]`` (filled in by ``_handler``, defined
    elsewhere) must be a ``UCXCanceled`` whose message names the receive API.
    """
    use_am = transfer_api == "am"
    feature = ucx_api.Feature.AM if use_am else ucx_api.Feature.TAG
    ctx = ucx_api.UCXContext(feature_flags=(feature, ))
    worker = ucx_api.UCXWorker(ctx)
    port = queue.get()
    ep = ucx_api.UCXEndpoint.create(
        worker, get_address(), port, endpoint_error_handling=True)

    # One-slot holder handed to the receive callback
    result = [None]

    if use_am:
        ucx_api.am_recv_nb(ep, cb_func=_handler, cb_args=(result, ))
        match_msg = ".*am_recv.*"
    else:
        msg = Array(bytearray(1))
        ucx_api.tag_recv_nb(
            worker,
            msg,
            msg.nbytes,
            tag=0,
            cb_func=_handler,
            cb_args=(result, ),
            ep=ep,
        )
        match_msg = ".*tag_recv_nb.*"

    # Wait for the endpoint to stop being alive
    while ep.is_alive():
        worker.progress()

    canceled = worker.cancel_inflight_messages()

    # Progress until the callback has stored something in the holder
    while result[0] is None:
        worker.progress()

    assert canceled == 1
    assert isinstance(result[0], UCXCanceled)
    assert re.match(match_msg, result[0].args[0])
Exemple #17
0
def test_ucxio(msg_size):
    """Write random bytes through ``UCXIO``, rewind and read them back."""
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=True)
    rkey = ep.unpack_rkey(packed_rkey)

    stream = ucx_api.UCXIO(mem.address, msg_size, rkey)
    payload = bytes(os.urandom(msg_size))
    stream.write(payload)
    stream.seek(0)
    received = stream.read(msg_size)
    assert payload == received
def _test_peer_communication_rma(queue, rank, msg_size):
    """Exchange rank-filled buffers with both neighbours using RMA gets.

    The peer fills its own memory with ``msg_size`` copies of ``rank``,
    publishes ``(rank, address, packed rkey, base address)`` on ``queue`` and
    reads the right and then the left neighbour's entry back. It then
    fetches both neighbours' memory and asserts each buffer holds that
    neighbour's rank. A tag send/recv between neighbours closes the function
    as a barrier.
    """
    ctx = ucx_api.UCXContext(
        feature_flags=(ucx_api.Feature.RMA, ucx_api.Feature.TAG))
    worker = ucx_api.UCXWorker(ctx)
    self_address = worker.get_address()
    mem_handle = ctx.alloc(msg_size)
    self_base = mem_handle.address
    self_prkey = mem_handle.pack_rkey()

    # Populate our own memory through a self-targeting endpoint
    self_ep, self_mem = _rma_setup(
        worker, self_address, self_prkey, self_base, msg_size)
    own_fill = bytes(repeat(rank, msg_size))
    if not self_mem.put_nbi(own_fill):
        blocking_flush(self_ep)

    queue.put((rank, self_address, self_prkey, self_base))
    right_rank, right_address, right_prkey, right_base = queue.get()
    left_rank, left_address, left_prkey, left_base = queue.get()

    right_ep, right_mem = _rma_setup(
        worker, right_address, right_prkey, right_base, msg_size)
    right_msg = bytearray(msg_size)
    right_mem.get_nbi(right_msg)

    left_ep, left_mem = _rma_setup(
        worker, left_address, left_prkey, left_base, msg_size)
    left_msg = bytearray(msg_size)
    left_mem.get_nbi(left_msg)

    # Both gets are flushed together before the buffers are inspected
    blocking_flush(worker)
    assert left_msg == bytes(repeat(left_rank, msg_size))
    assert right_msg == bytes(repeat(right_rank, msg_size))

    # We use the blocking tag send/recv as a barrier implementation
    barrier_buf = bytearray(8)
    if rank != 0:
        blocking_recv(worker, left_ep, barrier_buf, rank)
        blocking_send(worker, right_ep, barrier_buf, right_rank)
    else:
        token = bytes(os.urandom(8))
        blocking_send(worker, right_ep, token, right_rank)
        blocking_recv(worker, left_ep, barrier_buf, rank)
Exemple #19
0
def _client_probe(queue, transfer_api):
    """Send the wireup and data messages, then wait for the server's ack.

    The port is read from ``queue``. After both messages are sent the client
    asserts that the next item on ``queue`` is ``"wireup completed"`` and
    then returns.
    """
    use_am = transfer_api == "am"
    feature = ucx_api.Feature.AM if use_am else ucx_api.Feature.TAG
    ctx = ucx_api.UCXContext(feature_flags=(feature, ))
    worker = ucx_api.UCXWorker(ctx)
    port = queue.get()
    ep = ucx_api.UCXEndpoint.create(
        worker, get_address(), port, endpoint_error_handling=True)

    if use_am:
        _send = blocking_am_send
    else:
        _send = blocking_send

    for message in (WireupMessage, DataMessage):
        _send(worker, ep, message)

    # Wait for wireup before disconnecting
    assert queue.get() == "wireup completed"
Exemple #20
0
def test_implicit(msg_size):
    """Put then get random bytes on self-mapped memory using the nbi calls.

    After each of ``put_nbi`` / ``get_nbi`` the endpoint is flushed whenever
    the call returns a falsy value. The bytes read back must equal the bytes
    written.
    """
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, msg_size)
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=True)
    rkey = ep.unpack_rkey(packed_rkey)
    self_mem = ucx_api.RemoteMemory(rkey, mem.address, msg_size)

    outgoing = bytes(os.urandom(msg_size))
    put_done = self_mem.put_nbi(outgoing)
    if not put_done:
        blocking_flush(ep)

    incoming = bytearray(len(outgoing))
    get_done = self_mem.get_nbi(incoming)
    if not get_done:
        blocking_flush(ep)

    assert outgoing == incoming
Exemple #21
0
def test_ucxio_seek_good(seek_data):
    """Seek inside a ``UCXIO`` stream and read four bytes from there.

    ``seek_data`` is a ``(seek_flag, seek_dest)`` pair. The four bytes read
    after the seek must equal the written message sliced at ``seek_dest``.
    """
    seek_flag, seek_dest = seek_data
    ctx = ucx_api.UCXContext({})
    mem = ucx_api.UCXMemoryHandle.alloc(ctx, 1024)
    # Size everything from the length the handle reports
    msg_size = mem.length
    packed_rkey = mem.pack_rkey()
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=True)
    rkey = ep.unpack_rkey(packed_rkey)

    stream = ucx_api.UCXIO(mem.address, msg_size, rkey)
    payload = bytes(os.urandom(msg_size))
    stream.write(payload)
    stream.seek(seek_dest, seek_flag)
    received = stream.read(4)

    window_end = seek_dest + 4
    assert received == payload[seek_dest:window_end]
Exemple #22
0
def _echo_client(msg_size, datatype, port, endpoint_error_handling):
    """Send a wireup and a data active message and validate both echoes.

    ``datatype`` selects an entry of ``get_data()`` that supplies the
    allocator, memory type, generator and validator used here. The client
    connects to ``localhost:port``, sends both messages, receives both
    replies and checks them against what was sent.
    """
    data = get_data()[datatype]

    ctx = ucx_api.UCXContext(
        config_dict={"RNDV_THRESH": str(RNDV_THRESH)},
        feature_flags=(ucx_api.Feature.AM, ),
    )
    worker = ucx_api.UCXWorker(ctx)
    worker.register_am_allocator(data["allocator"], data["memory_type"])

    ep = ucx_api.UCXEndpoint.create(
        worker,
        "localhost",
        port,
        endpoint_error_handling=endpoint_error_handling,
    )

    # The wireup message is sent to ensure endpoints are connected, otherwise
    # UCX may not perform any rendezvous transfers.
    send_wireup = bytearray(b"wireup")
    send_data = data["generator"](msg_size)

    blocking_am_send(worker, ep, send_wireup)
    blocking_am_send(worker, ep, send_data)

    recv_wireup = blocking_am_recv(worker, ep)
    recv_data = blocking_am_recv(worker, ep)

    # Cast recv_wireup to bytearray so the comparison is also correct when
    # the host allocator (e.g. NumPy) returns a different container type
    assert bytearray(recv_wireup) == send_wireup

    if data["memory_type"] == "cuda" and send_data.nbytes < RNDV_THRESH:
        # Eager messages are always received on the host, if no host
        # allocator is registered UCX-Py defaults to `bytearray`.
        assert recv_data == bytearray(send_data.get())

    # Run the datatype's validator for every transfer, not only for eager
    # CUDA ones, so host and rendezvous messages are checked as well.
    data["validator"](recv_data, send_data)
Exemple #23
0
def _server(queue, server_close_callback):
    """Accept one connection and, optionally, wait for it to be closed.

    The listener port is published on ``queue``. When
    ``server_close_callback`` is True, ``_close_callback`` (bound to the
    one-slot ``closed`` flag) is registered on the accepted endpoint and the
    worker is progressed until that flag is no longer False. Otherwise the
    function returns as soon as the listener callback has run.

    Progressing happens only from the body of this function, never from
    inside the listener callback.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG, ))
    worker = ucx_api.UCXWorker(ctx)

    listener_finished = [False]
    closed = [False]

    # One-slot holder for the listener's endpoint, so it does not go out of
    # scope when the listener callback returns. A list is used (rather than
    # `global`) so the callback can rebind the slot without leaking a
    # module-level name.
    ep = [None]

    def _listener_handler(conn_request):
        ep[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        if server_close_callback is True:
            ep[0].set_close_callback(functools.partial(_close_callback, closed))
        listener_finished[0] = True

    listener = ucx_api.UCXListener(worker=worker,
                                   port=0,
                                   cb_func=_listener_handler)
    queue.put(listener.port)

    if server_close_callback is True:
        while closed[0] is False:
            worker.progress()
        assert closed[0] is True
    else:
        while listener_finished[0] is False:
            worker.progress()
Exemple #24
0
def test_ctx_map(buffer):
    """Map ``buffer`` through the context and check an rkey can be packed."""
    context = ucx_api.UCXContext({})
    handle = context.map(buffer)
    assert handle.pack_rkey() is not None
Exemple #25
0
def test_init_unknown_option():
    """Creating a context with an unrecognized option key must fail."""
    with pytest.raises(UCXConfigError):
        ucx_api.UCXContext({"UNKNOWN_OPTION": "3M"})
Exemple #26
0
def test_map(buffer):
    """Map ``buffer`` via ``UCXMemoryHandle.map`` and pack its rkey."""
    context = ucx_api.UCXContext({})
    handle = ucx_api.UCXMemoryHandle.map(context, buffer)
    assert handle.pack_rkey() is not None
Exemple #27
0
def test_alloc():
    """Allocate 1024 bytes via ``UCXMemoryHandle.alloc`` and pack the rkey."""
    context = ucx_api.UCXContext({})
    handle = ucx_api.UCXMemoryHandle.alloc(context, 1024)
    assert handle.pack_rkey() is not None
Exemple #28
0
def test_init_options():
    """Options passed to the context take precedence over the environment.

    ``UCX_SEG_SIZE`` is set to a conflicting value for the duration of the
    test; the context config must report the value given in ``options``.
    The previous state of the environment variable is restored afterwards so
    that it cannot influence later tests.
    """
    previous = os.environ.get("UCX_SEG_SIZE")
    os.environ["UCX_SEG_SIZE"] = "2M"  # Should be ignored
    try:
        options = {"SEG_SIZE": "3M"}
        ctx = ucx_api.UCXContext(options)
        config = ctx.get_config()
        assert config["SEG_SIZE"] == options["SEG_SIZE"]
    finally:
        if previous is None:
            os.environ.pop("UCX_SEG_SIZE", None)
        else:
            os.environ["UCX_SEG_SIZE"] = previous
Exemple #29
0
def test_ctx_alloc():
    """Allocate 1024 bytes through the context and pack the rkey."""
    context = ucx_api.UCXContext({})
    handle = context.alloc(1024)
    assert handle.pack_rkey() is not None
Exemple #30
0
def test_init_invalid_option():
    """Creating a context with an unparsable ``SEG_SIZE`` value must fail."""
    with pytest.raises(UCXConfigError):
        ucx_api.UCXContext({"SEG_SIZE": "invalid-size"})