Esempio n. 1
0
def test_feature_flags_mismatch(feature_flag):
    """Check that each transfer API refuses to run without its feature flag.

    The context is created with only ``feature_flag`` enabled. Every API
    family whose required feature differs from ``feature_flag`` must raise
    ``ValueError`` with a message naming the missing feature.
    """
    ctx = ucx_api.UCXContext(feature_flags=(feature_flag, ))
    worker = ucx_api.UCXWorker(ctx)
    addr = worker.get_address()
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, addr, endpoint_error_handling=False)
    msg = Array(bytearray(10))

    # (required feature, expected error pattern, calls of that API family),
    # listed in the same order the checks must run.
    expectations = [
        (
            ucx_api.Feature.TAG,
            "UCXContext must be created with `Feature.TAG`",
            (
                lambda: ucx_api.tag_send_nb(ep, msg, msg.nbytes, 0, None),
                lambda: ucx_api.tag_recv_nb(worker, msg, msg.nbytes, 0, None),
            ),
        ),
        (
            ucx_api.Feature.STREAM,
            "UCXContext must be created with `Feature.STREAM`",
            (
                lambda: ucx_api.stream_send_nb(ep, msg, msg.nbytes, None),
                lambda: ucx_api.stream_recv_nb(ep, msg, msg.nbytes, None),
            ),
        ),
        (
            ucx_api.Feature.AM,
            "UCXContext must be created with `Feature.AM`",
            (
                lambda: ucx_api.am_send_nbx(ep, msg, msg.nbytes, None),
                lambda: ucx_api.am_recv_nb(ep, None),
            ),
        ),
    ]

    for required, pattern, calls in expectations:
        if feature_flag == required:
            # This family is enabled; its calls are not expected to fail.
            continue
        for call in calls:
            with pytest.raises(ValueError, match=pattern):
                call()
Esempio n. 2
0
 def _listener_handler(conn_request):
     """Accept a connection request and post one tag receive on it.

     The new endpoint is published as the module-level ``ep``. A receive of
     ``msg_size`` bytes (tag 0) is posted, and ``_recv_handle`` is called with
     the endpoint and the receive buffer.
     """
     global ep
     ep = ucx_api.UCXEndpoint.create_from_conn_request(
         worker,
         conn_request,
         endpoint_error_handling=endpoint_error_handling,
     )
     recv_buf = Array(bytearray(msg_size))
     # The buffer travels in cb_args, presumably so it stays referenced until
     # the receive completes.
     ucx_api.tag_recv_nb(
         worker,
         recv_buf,
         recv_buf.nbytes,
         tag=0,
         cb_func=_recv_handle,
         cb_args=(ep, recv_buf),
     )
Esempio n. 3
0
def _client_cancel(queue, transfer_api):
    """Client whose pending receive must end up canceled.

    Connects to the server and posts one receive. The server closes without
    sending anything. After the endpoint stops being alive, the in-flight
    receive is canceled with ``worker.cancel_inflight_messages()``, and the
    receive callback must report a ``UCXCanceled`` error naming the receive
    operation.

    Parameters
    ----------
    queue
        Object whose ``get()`` yields the server's port.
    transfer_api : str
        ``"am"`` posts an active-message receive; any other value posts a
        tag receive.
    """
    feature_flags = (ucx_api.Feature.AM
                     if transfer_api == "am" else ucx_api.Feature.TAG, )
    ctx = ucx_api.UCXContext(feature_flags=feature_flags)
    worker = ucx_api.UCXWorker(ctx)
    port = queue.get()
    ep = ucx_api.UCXEndpoint.create(
        worker,
        get_address(),
        port,
        endpoint_error_handling=True,
    )

    # Filled by `_handler` with the outcome of the receive.
    ret = [None]

    if transfer_api == "am":
        ucx_api.am_recv_nb(ep, cb_func=_handler, cb_args=(ret, ))

        match_msg = ".*am_recv.*"
    else:
        # Kept as a local so the buffer stays referenced while the receive
        # is in flight.
        msg = Array(bytearray(1))
        ucx_api.tag_recv_nb(worker,
                            msg,
                            msg.nbytes,
                            tag=0,
                            cb_func=_handler,
                            cb_args=(ret, ),
                            ep=ep)

        match_msg = ".*tag_recv_nb.*"

    # Wait for the server to close the connection.
    while ep.is_alive():
        worker.progress()

    canceled = worker.cancel_inflight_messages()
    # Check the count before waiting. If nothing was canceled, the callback
    # may never fire and the wait loop below would spin forever instead of
    # failing.
    assert canceled == 1

    while ret[0] is None:
        worker.progress()

    assert isinstance(ret[0], UCXCanceled)
    assert re.match(match_msg, ret[0].args[0])
Esempio n. 4
0
    def _listener_handler(conn_request, msg):
        """Accept a connection request and post every benchmark receive on it.

        One wireup receive is posted first. One more receive follows for each
        of the ``args.n_iter + args.n_warmup_iter`` iterations. The receives
        are active-message receives when ``args.enable_am is True`` and tag
        receives otherwise. In tag mode, ``msg`` is reused as the receive
        buffer unless ``args.reuse_alloc`` is false, in which case a fresh
        ``args.n_bytes`` buffer is allocated per iteration.
        """
        global ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

        if args.enable_am is True:
            # Wireup receive, then one receive per iteration.
            ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
            for _ in range(args.n_iter + args.n_warmup_iter):
                ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
            return

        def _post_tag_recv(buf):
            # Count the operation as started, then post the receive into buf.
            op_started()
            ucx_api.tag_recv_nb(
                worker,
                buf,
                buf.nbytes,
                tag=0,
                cb_func=_tag_recv_handle,
                cb_args=(ep, buf),
            )

        # Wireup before starting to transfer data.
        _post_tag_recv(Array(bytearray(len(WireupMessage))))

        for _ in range(args.n_iter + args.n_warmup_iter):
            if not args.reuse_alloc:
                msg = Array(xp.zeros(args.n_bytes, dtype="u1"))
            _post_tag_recv(msg)
Esempio n. 5
0
def non_blocking_recv(worker, ep, msg, started_cb, completed_cb, tag=0):
    """Post a tag receive into ``msg`` without waiting for it to finish.

    ``started_cb`` is called before the receive is posted. ``completed_cb``
    is passed to ``non_blocking_handler`` through ``cb_args``. If
    ``ucx_api.tag_recv_nb`` returns ``None``, ``completed_cb`` is called
    directly instead.

    Returns the request object from ``ucx_api.tag_recv_nb``, which may be
    ``None``.
    """
    buf = Array(msg)
    started_cb()
    request = ucx_api.tag_recv_nb(
        worker,
        buf,
        buf.nbytes,
        tag=tag,
        cb_func=non_blocking_handler,
        cb_args=(completed_cb, ),
        ep=ep,
    )
    if request is None:
        # No request handle came back; presumably no callback will follow,
        # so report completion here.
        completed_cb()
    return request
Esempio n. 6
0
def blocking_recv(worker, ep, msg, tag=0):
    """Post a tag receive into ``msg`` and progress the worker until it is done.

    ``blocking_handler`` gets a one-element flag list through ``cb_args``.
    While a request is outstanding, the worker is progressed until that flag
    becomes truthy. If ``ucx_api.tag_recv_nb`` returns ``None``, no waiting
    happens.
    """
    buf = Array(msg)
    done = [False]
    request = ucx_api.tag_recv_nb(
        worker,
        buf,
        buf.nbytes,
        tag=tag,
        cb_func=blocking_handler,
        cb_args=(done, ),
        ep=ep,
    )
    pending = request is not None
    if pending:
        while not done[0]:
            worker.progress()