Пример #1
0
async def talk_to_server(ip, port):
    """Exchange one message with the server over a UCX endpoint (host memory).

    Connects to ``ip``:``port`` and allocates a host send buffer of
    ``1 << max_msg_log`` bytes. Unless ``args.blind_recv`` is set, it also
    allocates a host receive buffer of the same size. It then awaits a
    receive and, after that, a send. When ``args.check_data`` is set, the
    buffers are initialised with ``set_mem(1, size)`` before the transfer,
    and the value returned by ``recv_req.check_mem(0, size)`` is printed
    afterwards.

    The endpoint and all buffers are released even if a transfer raises.

    Reads the module-level ``args`` namespace, ``max_msg_log`` and the
    ``ucp`` binding.
    """
    # ``args`` and ``max_msg_log`` are module-level settings that are only
    # read here, so no ``global`` declaration is needed.
    msg_log = max_msg_log
    size = 1 << msg_log
    blind_recv = args.blind_recv

    start_string = "in talk_to_server"
    if blind_recv:
        start_string += " + blind recv"
    if args.check_data:
        start_string += " + data validity check"
    print(start_string)

    ep = ucp.get_endpoint(ip, port)

    send_buffer_region = ucp.BufferRegion()
    send_buffer_region.alloc_host(size)
    send_msg = ucp.Message(send_buffer_region)

    # A blind receive goes through ``ep.recv_future()``, which is not handed
    # a message object here, so only the non-blind path needs its own buffer.
    recv_buffer_region = None
    recv_msg = None
    if not blind_recv:
        recv_buffer_region = ucp.BufferRegion()
        recv_buffer_region.alloc_host(size)
        recv_msg = ucp.Message(recv_buffer_region)

    try:
        if args.check_data:
            send_msg.set_mem(1, size)
            if recv_msg is not None:
                recv_msg.set_mem(1, size)

        # NOTE(review): the receive is awaited before the send; presumably
        # the peer sends first — confirm against the server side.
        if recv_msg is not None:
            recv_req = await ep.recv(recv_msg, size)
        else:
            recv_req = await ep.recv_future()

        await ep.send(send_msg, size)

        if args.check_data:
            errs = recv_req.check_mem(0, size)
            print("num errs: " + str(errs))
    finally:
        # Tear down the endpoint before releasing its buffers.
        # NOTE(review): this presumably cancels or completes any operation
        # still posted against those buffers (relevant when an await above
        # raised) — confirm with ucp.
        ucp.destroy_ep(ep)
        send_buffer_region.free_host()
        if recv_buffer_region is not None:
            recv_buffer_region.free_host()

    print("done with talk_to_server")
Пример #2
0
async def talk_to_server(ip, port):
    """Exchange one message with the server over a UCX endpoint (CUDA memory).

    Connects to ``ip``:``port`` and allocates a CUDA send buffer of
    ``1 << max_msg_log`` bytes. Unless ``args.blind_recv`` is set, it also
    allocates a CUDA receive buffer of the same size. It then awaits a
    receive and, after that, a send. When ``args.use_fast`` is set, it uses
    ``recv_fast`` / ``send_fast`` instead of ``recv`` / ``send``.

    The endpoint and all buffers are released even if a transfer raises.

    Reads the module-level ``args`` namespace, ``max_msg_log`` and the
    ``ucp`` binding.
    """
    # ``args`` and ``max_msg_log`` are module-level settings that are only
    # read here, so no ``global`` declaration is needed.
    msg_log = max_msg_log
    size = 1 << msg_log
    blind_recv = args.blind_recv
    use_fast = args.use_fast

    start_string = "in talk_to_server"
    if blind_recv:
        start_string += " + blind recv"
    if use_fast:
        start_string += " + using fast ops"
    print(start_string)

    ep = ucp.get_endpoint(ip, port)
    print("got endpoint")

    send_buffer_region = ucp.BufferRegion()
    send_buffer_region.alloc_cuda(size)

    # A blind receive goes through ``ep.recv_future()``, which is not handed
    # a message object here, so only the non-blind path needs its own buffer.
    recv_buffer_region = None
    recv_msg = None
    if not blind_recv:
        recv_buffer_region = ucp.BufferRegion()
        recv_buffer_region.alloc_cuda(size)
        recv_msg = ucp.Message(recv_buffer_region)

    send_msg = ucp.Message(send_buffer_region)

    try:
        # NOTE(review): the receive is awaited before the send; presumably
        # the peer sends first — confirm against the server side.
        if blind_recv:
            await ep.recv_future()
        elif use_fast:
            await ep.recv_fast(recv_msg, size)
        else:
            await ep.recv(recv_msg, size)

        if use_fast:
            await ep.send_fast(send_msg, size)
        else:
            await ep.send(send_msg, size)
    finally:
        # Tear down the endpoint before releasing its buffers.
        # NOTE(review): this presumably cancels or completes any operation
        # still posted against those buffers (relevant when an await above
        # raised) — confirm with ucp.
        ucp.destroy_ep(ep)
        send_buffer_region.free_cuda()
        if recv_buffer_region is not None:
            recv_buffer_region.free_cuda()

    print("passed talk_to_server")
Пример #3
0
def test_cupy(dtype):
    """A BufferRegion populated from a CuPy array converts back to an equal array.

    Skipped when CuPy is not installed.
    """
    cupy = pytest.importorskip('cupy')
    source = cupy.ones(10, dtype)

    region = ucp.BufferRegion()
    region.populate_cuda_ptr(source)

    cupy.testing.assert_array_equal(cupy.asarray(region), source)
Пример #4
0
def test_numba_empty():
    """An empty Numba device array gives a zero-length region whose data address is 0.

    Skipped when Numba is not installed.
    """
    numba = pytest.importorskip("numba")
    import numba.cuda  # noqa

    empty = numba.cuda.device_array(0)
    region = ucp.BufferRegion()
    region.populate_cuda_ptr(empty)

    assert len(region) == 0
    address = region.__cuda_array_interface__["data"][0]
    assert address == 0
Пример #5
0
def test_set_read():
    """Populate a BufferRegion from host bytes and read them back through a memoryview."""
    source = memoryview(b'hi')
    region = ucp.BufferRegion()
    region.populate_ptr(source)

    view = memoryview(region)
    assert view == source
    assert view.tobytes() == source.tobytes()

    # BufferRegion-specific attributes
    assert region.is_cuda == 0
    assert region.shape[0] == 2
Пример #6
0
def test_numpy(dtype, data):
    """Populate a BufferRegion from a NumPy array and read it back unchanged.

    ``data`` selects the object handed to ``populate_ptr``:

    * truthy: the array's buffer view, ``arr.data``;
    * falsy: the ndarray itself.

    Both branches previously passed ``arr.data``, so the falsy case never
    exercised the ndarray input.

    Skipped when NumPy is not installed.
    """
    np = pytest.importorskip("numpy")
    arr = np.ones(10, dtype)

    buffer_region = ucp.BufferRegion()
    # NOTE(review): assumes populate_ptr accepts any object exposing the
    # buffer protocol (ndarray does) — confirm against BufferRegion.
    buffer_region.populate_ptr(arr.data if data else arr)

    result = np.asarray(buffer_region)
    np.testing.assert_array_equal(result, arr)
Пример #7
0
def test_alloc_cuda_raises():
    """Expect ``alloc_cuda`` to raise ValueError whose message matches ``msg``.

    ``msg`` is defined outside this function.
    """
    region = ucp.BufferRegion()
    with pytest.raises(ValueError, match=msg):
        region.alloc_cuda(10)
Пример #8
0
def test_free_cuda_raises():
    """Expect ``free_cuda`` on a fresh BufferRegion to raise ValueError matching ``msg``.

    ``msg`` is defined outside this function.
    """
    region = ucp.BufferRegion()
    with pytest.raises(ValueError, match=msg):
        region.free_cuda()