예제 #1
0
    async def client_node(port):
        """Client side of the ``close_after_n_recv`` test.

        Opens four endpoints to ``port`` in turn and checks that each closes
        itself after exactly the requested number of receives (counted from
        the call, or from endpoint creation), and that invalid or repeated
        ``close_after_n_recv`` calls raise ``UCXError``.
        """

        async def connect():
            return await ucp.create_endpoint(ucp.get_address(), port)

        async def drain(endpoint, count):
            # Perform `count` shutdown-aware receives on `endpoint`.
            for _ in range(count):
                await _shutdown_recv(endpoint, message_type)

        # Close after 10 receives counted from the call.
        endpoint = await connect()
        endpoint.close_after_n_recv(10)
        await drain(endpoint, 10)
        assert endpoint.closed()

        # 5 receives first, then close after 5 more counted from the call.
        endpoint = await connect()
        await drain(endpoint, 5)
        endpoint.close_after_n_recv(5)
        await drain(endpoint, 5)
        assert endpoint.closed()

        # 5 receives first, then close at 10 counted from creation.
        endpoint = await connect()
        await drain(endpoint, 5)
        endpoint.close_after_n_recv(10, count_from_ep_creation=True)
        await drain(endpoint, 5)
        assert endpoint.closed()

        # Error paths: target below the current count, and setting twice.
        endpoint = await connect()
        await drain(endpoint, 10)
        with pytest.raises(
                ucp.exceptions.UCXError,
                match="`n` cannot be less than current recv_count",
        ):
            endpoint.close_after_n_recv(5, count_from_ep_creation=True)

        endpoint.close_after_n_recv(1)
        with pytest.raises(
                ucp.exceptions.UCXError,
                match="close_after_n_recv has already been set to",
        ):
            endpoint.close_after_n_recv(1)
예제 #2
0
def test_mismatch(server_guarantee_msg_order):
    """Endpoint creation must fail when peers disagree on ``guarantee_msg_order``.

    The client side raises ``ValueError`` directly. Errors raised on the
    server side surface through the event loop's exception handler, which is
    temporarily replaced to record whether any of them carried an unexpected
    message. The previous handler is restored afterwards.
    """
    loop = asyncio.get_event_loop()
    loop.test_failed = False
    loop.error_msg_expected = "Both peers must set guarantee_msg_order identically"

    # We use an exception handler to catch errors raised by the server.
    def handle_exception(loop, context):
        msg = str(context.get("exception", context["message"]))
        # Accumulate rather than overwrite: a later matching error must not
        # hide an earlier unexpected one.
        if msg.find(loop.error_msg_expected) == -1:
            loop.test_failed = True

    previous_handler = loop.get_exception_handler()
    loop.set_exception_handler(handle_exception)
    try:
        with pytest.raises(ValueError, match=loop.error_msg_expected):
            lt = ucp.create_listener(
                lambda x: x, guarantee_msg_order=server_guarantee_msg_order)
            loop.run_until_complete(
                ucp.create_endpoint(
                    ucp.get_address(),
                    lt.port,
                    guarantee_msg_order=(not server_guarantee_msg_order),
                ))
        loop.run_until_complete(
            asyncio.sleep(0.1))  # Give the server time to finish
    finally:
        # Don't leak this test's handler into later code sharing the loop.
        loop.set_exception_handler(previous_handler)

    assert not loop.test_failed, "expected error message not raised by the server"
예제 #3
0
 async def client_node(port):
     """Connect to ``port`` and close the endpoint mid-receive.

     Closing while ``_shutdown_recv`` is still pending is expected to make
     the gathered operations fail with ``UCXCanceled``.
     """
     address = ucp.get_address()
     endpoint = await ucp.create_endpoint(address, port)
     with pytest.raises(ucp.exceptions.UCXCanceled):
         recv_and_close = (_shutdown_recv(endpoint, message_type), endpoint.close())
         await asyncio.gather(*recv_and_close)
예제 #4
0
async def test_send_recv_am(size, blocking_progress_mode, recv_wait, data):
    """Send one active message per client and validate what the server got.

    ``data`` supplies the AM allocator, memory type, message generator and
    validator for the memory type under test.
    """
    rndv_thresh = 8192
    ucp.init(
        options={"RNDV_THRESH": str(rndv_thresh)},
        blocking_progress_mode=blocking_progress_mode,
    )

    ucp.register_am_allocator(data["allocator"], data["memory_type"])
    msg = data["generator"](size)

    received = []
    listener = ucp.create_listener(simple_server(size, received))
    num_clients = 1
    clients = []
    for _ in range(num_clients):
        clients.append(await ucp.create_endpoint(ucp.get_address(), listener.port))

    for client in clients:
        if recv_wait:
            # Sleeping first forces the listener's ep.am_recv call to wait
            # for data instead of returning with data already available.
            await asyncio.sleep(1)
        await client.am_send(msg)
    for client in clients:
        await client.close()
    listener.close()

    eager_cuda = data["memory_type"] == "cuda" and msg.nbytes < rndv_thresh
    if not eager_cuda:
        data["validator"](received[0], msg)
    else:
        # Eager messages are always received on the host; with no host
        # allocator registered UCX-Py defaults to `bytearray`.
        assert received[0] == bytearray(msg.get())
예제 #5
0
async def test_tag_match():
    """A receive only completes for a send carrying a matching tag.

    The server sends ``msg1`` (tag "msg1"), sleeps, then sends ``msg2``
    (tag "msg2"). A receive for "msg2" posted up front must still be pending
    while only "msg1" has been sent.
    """
    msg1 = bytes("msg1", "utf-8")
    msg2 = bytes("msg2", "utf-8")

    async def server_node(ep):
        f1 = ep.send(msg1, tag="msg1")
        await asyncio.sleep(1)  # Let msg1 finish
        f2 = ep.send(msg2, tag="msg2")
        await asyncio.gather(f1, f2)

    lf = ucp.create_listener(server_node)
    ep = await ucp.create_endpoint(ucp.get_address(), lf.port)
    m1, m2 = (bytearray(len(msg1)), bytearray(len(msg2)))
    # `asyncio.create_task` only exists on Python >= 3.7; fall back to
    # `ensure_future` on 3.6. Probe the module for `create_task` itself:
    # `create_future` is an event-loop method, not a module attribute, so
    # probing for it would always select the fallback.
    if hasattr(asyncio, "create_task"):
        f2 = asyncio.create_task(ep.recv(m2, tag="msg2"))
    else:
        f2 = asyncio.ensure_future(ep.recv(m2, tag="msg2"))

    # At this point f2 shouldn't be able to finish because its
    # tag "msg2" doesn't match the server's send tag "msg1".
    _, pending = await asyncio.wait({f2}, timeout=0.01)
    assert f2 in pending
    # "msg1" should be ready
    await ep.recv(m1, tag="msg1")
    assert m1 == msg1
    await f2
    assert m2 == msg2
예제 #6
0
 async def client_node(listener):
     """Receive two int64 arrays, then close the listener and verify its state."""
     endpoint = await ucp.create_endpoint(ucp.get_address(), listener.port)
     buffer = np.empty(100, dtype=np.int64)
     for _ in range(2):
         await endpoint.recv(buffer)
     assert listener.closed() is False
     listener.close()
     assert listener.closed() is True
예제 #7
0
 async def client_node(port):
     """Connect to ``port``; then register the close callback or close the ep.

     Driven by the enclosing ``server_close_callback`` flag. When it is
     ``True`` the client closes the endpoint itself. When it is ``False`` the
     client registers ``_close_callback`` on the endpoint.
     """
     endpoint = await ucp.create_endpoint(ucp.get_address(), port)
     if server_close_callback is True:
         await endpoint.close()
     elif server_close_callback is False:
         endpoint.set_close_callback(_close_callback)
예제 #8
0
 async def client_node(port):
     """Connect to ``port`` and send ``msg`` via AM or tag API per ``transfer_api``."""
     endpoint = await ucp.create_endpoint(ucp.get_address(), port)
     send = endpoint.am_send if transfer_api == "am" else endpoint.send
     await send(msg)
예제 #9
0
 async def run():
     """Connect to the port read from ``client_queue`` and receive one array."""
     port = client_queue.get()
     endpoint = await ucp.create_endpoint(
         ucp.get_address(),
         port,
         endpoint_error_handling=endpoint_error_handling,
     )
     await endpoint.recv(np.empty(100, dtype=np.int64))
예제 #10
0
 async def client_node(port):
     """Connect, send 10 bytes, and handle closing per ``server_close_callback``."""
     endpoint = await ucp.create_endpoint(
         ucp.get_address(), port, endpoint_error_handling=endpoint_error_handling
     )
     client_registers_callback = server_close_callback is False
     if client_registers_callback:
         endpoint.set_close_callback(_close_callback)
     payload = bytearray(b"0" * 10)
     await endpoint.send(payload)
     if server_close_callback is True:
         await endpoint.close()
예제 #11
0
async def test_mismatch():
    """Creating an endpoint fails if peers disagree on ``guarantee_msg_order``.

    Checked in both directions: listener ``True``/client ``False`` and
    listener ``False``/client ``True``.
    """
    def server(ep):
        pass

    for listener_order in (True, False):
        lt = ucp.create_listener(server, guarantee_msg_order=listener_order)
        with pytest.raises(
                ValueError,
                match="Both peers must set guarantee_msg_order identically"):
            await ucp.create_endpoint(ucp.get_address(),
                                      lt.port,
                                      guarantee_msg_order=not listener_order)
예제 #12
0
 async def client_node(listener):
     """Receive two shutdown-aware messages, then close the listener and verify it."""
     endpoint = await ucp.create_endpoint(ucp.get_address(), listener.port)
     for _ in range(2):
         await _shutdown_recv(endpoint, message_type)
     assert listener.closed() is False
     listener.close()
     assert listener.closed() is True
예제 #13
0
async def test_get_ucp_worker():
    """The raw UCP worker handle is an int reported identically by endpoints."""
    worker_handle = ucp.get_ucp_worker()
    assert isinstance(worker_handle, int)

    def server(ep):
        assert ep.get_ucp_worker() == worker_handle

    listener = ucp.create_listener(server)
    client = await ucp.create_endpoint(ucp.get_address(), listener.port)
    assert client.get_ucp_worker() == worker_handle
예제 #14
0
async def test_get_endpoint():
    """Both peers expose a positive integer raw UCP endpoint handle."""
    def check_handle(endpoint):
        handle = endpoint.get_ucp_endpoint()
        assert isinstance(handle, int)
        assert handle > 0

    def server(ep):
        check_handle(ep)

    listener = ucp.create_listener(server)
    client = await ucp.create_endpoint(ucp.get_address(), listener.port)
    check_handle(client)
예제 #15
0
async def test_zero_port():
    """Starting a listener on port 0 yields a valid, non-zero port number.

    Also runs one client/server exchange over that port. ``ucp.fin()`` is
    called in ``finally`` so the UCX context initialised here is torn down
    even if the assertion or the exchange fails.
    """
    ucp.init()
    try:
        listener = ucp.start_listener(talk_to_client,
                                      listener_port=0,
                                      is_coroutine=True)
        # Port 0 asks for an ephemeral port; the assigned one must fit in
        # the 16-bit port range and not be 0.
        assert 0 < listener.port < 2**16

        ip = ucp.get_address()
        await asyncio.gather(listener.coroutine,
                             talk_to_server(ip.encode(), listener.port))
    finally:
        ucp.fin()
예제 #16
0
    async def client_node(port):
        """Client side of the ``close_after_n_recv`` test using NumPy buffers.

        Each scenario connects a fresh endpoint and checks that it closes
        after exactly the requested number of receives, or that invalid or
        repeated ``close_after_n_recv`` calls raise ``UCXError``.
        """

        async def connect():
            return await ucp.create_endpoint(ucp.get_address(), port)

        async def recv_n(endpoint, count):
            # Each receive lands in a fresh 10-element float buffer.
            for _ in range(count):
                await endpoint.recv(np.empty(10))

        # Close after 10 receives counted from the call.
        ep = await connect()
        ep.close_after_n_recv(10)
        await recv_n(ep, 10)
        assert ep.closed()

        # 5 receives first, then close after 5 more counted from the call.
        ep = await connect()
        await recv_n(ep, 5)
        ep.close_after_n_recv(5)
        await recv_n(ep, 5)
        assert ep.closed()

        # 5 receives first, then close at 10 counted from creation.
        ep = await connect()
        await recv_n(ep, 5)
        ep.close_after_n_recv(10, count_from_ep_creation=True)
        await recv_n(ep, 5)
        assert ep.closed()

        # Error paths: target below current count, and setting it twice.
        ep = await connect()
        await recv_n(ep, 10)

        with pytest.raises(ucp.exceptions.UCXError,
                           match="`n` cannot be less than current recv_count"):
            ep.close_after_n_recv(5, count_from_ep_creation=True)

        ep.close_after_n_recv(1)
        with pytest.raises(ucp.exceptions.UCXError,
                           match="close_after_n_recv has already been set to"):
            ep.close_after_n_recv(1)
예제 #17
0
async def test_reset():
    """Reset UCX after the listener and endpoint references are released.

    ``reset`` is called once from the server callback and once here.
    NOTE(review): ``ResetAfterN(2)`` presumably performs the actual
    ``ucp.reset()`` only on its second call — confirm against its definition.
    """
    reset = ResetAfterN(2)

    def server(ep):
        # Server side: abort its endpoint, then count one reset call.
        ep.abort()
        reset()

    lt = ucp.create_listener(server)
    ep = await ucp.create_endpoint(ucp.get_address(), lt.port)
    # Drop both client-side references before resetting; presumably
    # required for the reset to succeed — TODO confirm.
    del lt
    del ep
    reset()
예제 #18
0
async def test_send_recv_bytes(size, blocking_progress_mode):
    """Round-trip a ``size``-byte bytearray through an echo server."""
    ucp.init(blocking_progress_mode=blocking_progress_mode)

    payload = bytearray(b"m" * size)
    header = np.array([len(payload)], dtype=np.uint64)

    listener = ucp.create_listener(make_echo_server(lambda n: bytearray(n)))
    client = await ucp.create_endpoint(ucp.get_address(), listener.port)
    # A uint64 length header goes out ahead of the payload (presumably
    # consumed by make_echo_server to size its receive buffer).
    for frame in (header, payload):
        await client.send(frame)
    echoed = bytearray(size)
    await client.recv(echoed)
    assert echoed == payload
예제 #19
0
async def test_listener_del():
    """The client deletes the listener; the open endpoint still receives."""
    async def server_node(ep):
        for _ in range(2):
            await ep.send(np.arange(100, dtype=np.int64))

    listener = ucp.create_listener(server_node)
    ep = await ucp.create_endpoint(ucp.get_address(), listener.port)
    buf = np.empty(100, dtype=np.int64)
    await ep.recv(buf)
    assert listener.closed() is False
    del listener
    await ep.recv(buf)
예제 #20
0
async def test_send_recv_numpy(size, dtype, blocking_progress_mode):
    """Round-trip a NumPy array of ``size`` elements through an echo server."""
    ucp.init(blocking_progress_mode=blocking_progress_mode)

    sent = np.arange(size, dtype=dtype)
    header = np.array([sent.nbytes], dtype=np.uint64)

    def allocate(n):
        # Server-side receive buffer of n raw bytes.
        return np.empty(n, dtype=np.uint8)

    listener = ucp.create_listener(make_echo_server(allocate))
    client = await ucp.create_endpoint(ucp.get_address(), listener.port)
    await client.send(header)
    await client.send(sent)
    received = np.empty_like(sent)
    await client.recv(received)
    np.testing.assert_array_equal(received, sent)
예제 #21
0
async def test_send_recv_obj(blocking_progress_mode):
    """``send_obj``/``recv_obj`` round-trip a bytearray through an echo server."""
    ucp.init(blocking_progress_mode=blocking_progress_mode)

    async def echo_obj_server(ep):
        await ep.send_obj(await ep.recv_obj())

    listener = ucp.create_listener(echo_obj_server)
    client = await ucp.create_endpoint(ucp.get_address(), listener.port)

    payload = bytearray(b"hello")
    await client.send_obj(payload)
    assert payload == await client.recv_obj()
예제 #22
0
async def test_ep_still_in_scope_error():
    """``ucp.reset()`` must refuse to run while a client endpoint is alive.

    The listener reference is dropped but ``ep`` is kept, so the first
    ``ucp.reset()`` is expected to raise ``UCXError`` mentioning
    ``_ucp_endpoint``. After ``ep.abort()`` the reset is expected to succeed.
    """
    reset = ResetAfterN(2)  # NOTE(review): presumably resets on its 2nd call — confirm

    def server(ep):
        # Server side: abort its endpoint, then count one reset call.
        ep.abort()
        reset()

    lt = ucp.create_listener(server)
    ep = await ucp.create_endpoint(ucp.get_address(), lt.port)
    del lt
    # `ep` is still referenced here, so reset() must fail.
    with pytest.raises(ucp.exceptions.UCXError, match="_ucp_endpoint"):
        ucp.reset()
    ep.abort()
    ucp.reset()
예제 #23
0
async def echo_pair(cuda_info=None):
    """Yield a ``(listener, client)`` pair connected to an echo server.

    This is an async generator that yields once. Setup:
    - initialise UCX;
    - start a listener serving ``ucp.make_server(cuda_info)`` as a task on
      the current event loop;
    - connect a client endpoint to it.

    On exit, teardown:
    - destroy the client endpoint;
    - await the listener task;
    - finalise UCX.

    NOTE(review): the setup steps run before the ``try``, so if
    ``get_endpoint`` raises, the listener task and the UCX context are not
    cleaned up. Confirm whether callers rely on that.
    """
    ucp.init()
    loop = asyncio.get_event_loop()
    listener = ucp.start_listener(ucp.make_server(cuda_info),
                                  is_coroutine=True)
    t = loop.create_task(listener.coroutine)
    address = ucp.get_address()
    client = await ucp.get_endpoint(address.encode(), listener.port)
    try:
        yield listener, client
    finally:
        # Destroy the client before awaiting the listener task; presumably
        # the listener coroutine finishes once its peer goes away — TODO
        # confirm.
        ucp.destroy_ep(client)
        await t
        ucp.fin()
예제 #24
0
async def test_send_recv_error(blocking_progress_mode):
    """Receiving into a buffer of the wrong size raises a length-mismatch error."""
    ucp.init(blocking_progress_mode=blocking_progress_mode)

    async def say_hey_server(ep):
        await ep.send(bytearray(b"Hey"))

    listener = ucp.create_listener(say_hey_server)
    client = await ucp.create_endpoint(ucp.get_address(), listener.port)

    # The server sends 3 bytes while the client posts a 100-byte receive.
    oversized = bytearray(100)
    expected = r"length mismatch: 3 \(got\) != 100 \(expected\)"
    with pytest.raises(ucp.exceptions.UCXError, match=expected):
        await client.recv(oversized)
예제 #25
0
async def test_send_recv_numba(size, dtype, blocking_progress_mode):
    """Round-trip a Numba CUDA device array through an echo server.

    Skipped when ``numba.cuda`` is unavailable. The skip check runs before
    ``ucp.init`` so that a skip does not leave a freshly initialised UCX
    context behind.
    """
    cuda = pytest.importorskip("numba.cuda")
    ucp.init(blocking_progress_mode=blocking_progress_mode)

    ary = np.arange(size, dtype=dtype)
    msg = cuda.to_device(ary)
    msg_size = np.array([msg.nbytes], dtype=np.uint64)
    listener = ucp.create_listener(
        make_echo_server(lambda n: cuda.device_array((n, ), dtype=np.uint8)))
    client = await ucp.create_endpoint(ucp.get_address(), listener.port)
    await client.send(msg_size)
    await client.send(msg)
    resp = cuda.device_array_like(msg)
    await client.recv(resp)
    # Copy both device arrays back to the host for comparison.
    np.testing.assert_array_equal(np.array(resp), np.array(msg))
예제 #26
0
async def test_lt_still_in_scope_error():
    """``ucp.reset()`` must refuse to run while a listener is alive.

    The endpoint reference is dropped but ``lt`` is kept, so the first
    ``ucp.reset()`` is expected to raise ``UCXError`` mentioning
    ``ucp._libs.core._Listener``. After ``lt.close()`` the reset is expected
    to succeed.
    """
    reset = ResetAfterN(2)  # NOTE(review): presumably resets on its 2nd call — confirm

    def server(ep):
        # Server side: abort its endpoint, then count one reset call.
        ep.abort()
        reset()

    lt = ucp.create_listener(server)
    ep = await ucp.create_endpoint(ucp.get_address(), lt.port)
    del ep
    # `lt` is still referenced here, so reset() must fail.
    with pytest.raises(ucp.exceptions.UCXError, match="ucp._libs.core._Listener"):
        ucp.reset()

    lt.close()
    ucp.reset()
예제 #27
0
async def test_listener_del(message_type):
    """The client deletes the listener; the open endpoint still receives."""
    async def server_node(ep):
        for _ in range(2):
            await _shutdown_send(ep, message_type)

    listener = ucp.create_listener(server_node)
    ep = await ucp.create_endpoint(ucp.get_address(), listener.port)
    await _shutdown_recv(ep, message_type)
    assert listener.closed() is False
    del listener
    await _shutdown_recv(ep, message_type)
async def tmp():
    """Ad-hoc smoke check: ping two local servers and wait for their replies.

    NOTE(review): assumes servers are already listening on ports 13337 and
    13338 at the local UCX address — confirm before running.
    """
    addr = ucp.get_address().encode('utf-8')
    # In this legacy API `get_endpoint` is awaited to obtain the endpoint;
    # without `await`, ep1/ep2 would be bare coroutine objects.
    ep1 = await ucp.get_endpoint(addr, 13337)
    ep2 = await ucp.get_endpoint(addr, 13338)

    await ep1.send_obj(b'hi')
    print("past send1")
    recv_ft1 = ep1.recv_future()
    await recv_ft1
    print("past recv1")

    await ep2.send_obj(b'hi')
    recv_ft2 = ep2.recv_future()
    await recv_ft2
    print("past recv2")
예제 #29
0
async def test_send_recv_obj_numpy(blocking_progress_mode):
    """``send_obj``/``recv_obj`` round-trip using a NumPy uint8 allocator."""
    ucp.init(blocking_progress_mode=blocking_progress_mode)

    # Both peers receive into NumPy uint8 buffers instead of the default.
    allocator = functools.partial(np.empty, dtype=np.uint8)

    async def echo_obj_server(ep):
        received = await ep.recv_obj(allocator=allocator)
        await ep.send_obj(received)

    listener = ucp.create_listener(echo_obj_server)
    client = await ucp.create_endpoint(ucp.get_address(), listener.port)

    payload = bytearray(b"hello")
    await client.send_obj(payload)
    echoed = await client.recv_obj(allocator=allocator)
    assert payload == echoed
예제 #30
0
 async def client_node(port):
     """Post a receive that is expected to fail with ``UCXCanceled``, then close.

     ``transfer_api`` selects the AM (``am_recv``) or tag (``recv``) path and
     the matching error-message pattern.
     """
     endpoint = await ucp.create_endpoint(ucp.get_address(), port)
     pattern = "am_recv" if transfer_api == "am" else "Recv.*tag"
     with pytest.raises(ucp.exceptions.UCXCanceled, match=pattern):
         if transfer_api == "am":
             await endpoint.am_recv()
         else:
             await endpoint.recv(bytearray(1))
     await endpoint.close()