Esempio n. 1
0
def test_feature_flags_mismatch(feature_flag):
    """Operations of a feature the context lacks must be rejected.

    A context is created with only ``feature_flag`` enabled. A worker and an
    endpoint connected to that worker's own address are built on it. For
    each of TAG, STREAM and AM that differs from ``feature_flag``, the
    feature's send and receive entry points are each expected to raise
    ``ValueError`` whose message names the missing feature.

    NOTE(review): ``feature_flag`` is presumably supplied by a pytest
    parametrization over ``ucx_api.Feature`` members; the decorator is not
    visible here.
    """
    context = ucx_api.UCXContext(feature_flags=(feature_flag, ))
    worker = ucx_api.UCXWorker(context)
    endpoint = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=False)
    buf = Array(bytearray(10))

    # Each entry: (feature, expected error pattern, calls needing it).
    # The calls are deferred in lambdas so they run inside pytest.raises.
    cases = [
        (ucx_api.Feature.TAG,
         "UCXContext must be created with `Feature.TAG`",
         (lambda: ucx_api.tag_send_nb(endpoint, buf, buf.nbytes, 0, None),
          lambda: ucx_api.tag_recv_nb(worker, buf, buf.nbytes, 0, None))),
        (ucx_api.Feature.STREAM,
         "UCXContext must be created with `Feature.STREAM`",
         (lambda: ucx_api.stream_send_nb(endpoint, buf, buf.nbytes, None),
          lambda: ucx_api.stream_recv_nb(endpoint, buf, buf.nbytes, None))),
        (ucx_api.Feature.AM,
         "UCXContext must be created with `Feature.AM`",
         (lambda: ucx_api.am_send_nbx(endpoint, buf, buf.nbytes, None),
          lambda: ucx_api.am_recv_nb(endpoint, None))),
    ]

    for feature, pattern, operations in cases:
        # Only features that were NOT requested are expected to fail.
        if feature_flag != feature:
            for operation in operations:
                with pytest.raises(ValueError, match=pattern):
                    operation()
Esempio n. 2
0
 def _recv_handle(recv_obj, exception, ep):
     """Completion callback for an active-message receive.

     Asserts that the receive reported no error. It then wraps the received
     object in ``Array`` and sends it out on ``ep`` with ``am_send_nbx``.
     ``_send_handle`` is the send's completion callback, and the wrapped
     buffer is passed as its argument. Passing the buffer presumably keeps
     it alive until the send finishes; confirm against ``_send_handle``.
     """
     assert exception is None
     payload = Array(recv_obj)
     ucx_api.am_send_nbx(
         ep, payload, payload.nbytes,
         cb_func=_send_handle, cb_args=(payload, ))
Esempio n. 3
0
def blocking_am_send(worker, ep, msg):
    """Send ``msg`` as an active message on ``ep`` and wait for completion.

    ``msg`` is wrapped in ``Array`` and handed to ``ucx_api.am_send_nbx``.
    ``blocking_handler`` is the completion callback, and a one-element list
    flag is its argument. If the send returns ``None``, there is nothing to
    wait for. Otherwise ``worker.progress()`` is called repeatedly until the
    flag becomes truthy.

    NOTE(review): ``blocking_handler`` presumably sets ``flag[0] = True``
    when it runs. The wait loop has no timeout, so it spins forever if that
    never happens.
    """
    buffer = Array(msg)
    done = [False]  # mutable cell so the callback can signal completion
    request = ucx_api.am_send_nbx(
        ep, buffer, buffer.nbytes,
        cb_func=blocking_handler, cb_args=(done, ))
    if request is None:
        # No outstanding request was returned; nothing to wait for.
        return
    while not done[0]:
        worker.progress()