async def talk_to_server(ip, port):
    """Exchange one ``1 << max_msg_log``-byte message with the server at ``ip``:``port``.

    Uses host-memory buffers (``alloc_host``/``free_host``). Behavior is
    controlled by the module-level ``args`` namespace:

    * ``args.blind_recv`` -- receive through ``ep.recv_future()`` instead of
      into a pre-allocated receive buffer.
    * ``args.check_data`` -- initialise the buffers with ``set_mem(1, ...)``
      before the exchange and print the count returned by
      ``recv_req.check_mem(0, ...)`` afterwards.

    The buffers and the endpoint are released even if the exchange raises.
    """
    start_string = "in talk_to_server"
    if args.blind_recv:
        start_string += " + blind recv"
    if args.check_data:
        start_string += " + data validity check"
    print(start_string)

    nbytes = 1 << max_msg_log
    ep = ucp.get_endpoint(ip, port)
    # Track only regions that were successfully allocated, so the
    # finally-clause never frees something that was never allocated.
    send_buffer_region = None
    recv_buffer_region = None
    try:
        region = ucp.BufferRegion()
        region.alloc_host(nbytes)
        send_buffer_region = region
        send_msg = ucp.Message(send_buffer_region)

        recv_msg = None
        if not args.blind_recv:
            region = ucp.BufferRegion()
            region.alloc_host(nbytes)
            recv_buffer_region = region
            recv_msg = ucp.Message(recv_buffer_region)

        if args.check_data:
            send_msg.set_mem(1, nbytes)
            if recv_msg is not None:
                recv_msg.set_mem(1, nbytes)

        # Post the receive before sending, as the original code does.
        if recv_msg is not None:
            recv_req = await ep.recv(recv_msg, nbytes)
        else:
            recv_req = await ep.recv_future()
        await ep.send(send_msg, nbytes)

        if args.check_data:
            errs = recv_req.check_mem(0, nbytes)
            print("num errs: " + str(errs))
    finally:
        if send_buffer_region is not None:
            send_buffer_region.free_host()
        if recv_buffer_region is not None:
            recv_buffer_region.free_host()
        ucp.destroy_ep(ep)
    print("done with talk_to_server")
async def talk_to_server(ip, port):
    """Exchange one ``1 << max_msg_log``-byte message with the server at ``ip``:``port``.

    Uses CUDA buffers (``alloc_cuda``/``free_cuda``). Behavior is controlled
    by the module-level ``args`` namespace:

    * ``args.blind_recv`` -- receive through ``ep.recv_future()`` instead of
      into a pre-allocated receive buffer.
    * ``args.use_fast`` -- use ``recv_fast``/``send_fast`` instead of
      ``recv``/``send``. ``recv_future`` is used for blind receives either way.

    The buffers and the endpoint are released even if the exchange raises.
    """
    start_string = "in talk_to_server"
    if args.blind_recv:
        start_string += " + blind recv"
    if args.use_fast:
        start_string += " + using fast ops"
    print(start_string)

    nbytes = 1 << max_msg_log
    ep = ucp.get_endpoint(ip, port)
    print("got endpoint")
    # Track only regions that were successfully allocated, so the
    # finally-clause never frees something that was never allocated.
    send_buffer_region = None
    recv_buffer_region = None
    try:
        region = ucp.BufferRegion()
        region.alloc_cuda(nbytes)
        send_buffer_region = region

        recv_msg = None
        if not args.blind_recv:
            region = ucp.BufferRegion()
            region.alloc_cuda(nbytes)
            recv_buffer_region = region
            recv_msg = ucp.Message(recv_buffer_region)
        send_msg = ucp.Message(send_buffer_region)

        # Post the receive before sending, as the original code does.
        if recv_msg is None:
            await ep.recv_future()
        elif args.use_fast:
            await ep.recv_fast(recv_msg, nbytes)
        else:
            await ep.recv(recv_msg, nbytes)

        if args.use_fast:
            await ep.send_fast(send_msg, nbytes)
        else:
            await ep.send(send_msg, nbytes)
    finally:
        if send_buffer_region is not None:
            send_buffer_region.free_cuda()
        if recv_buffer_region is not None:
            recv_buffer_region.free_cuda()
        ucp.destroy_ep(ep)
    print("passed talk_to_server")
def test_cupy(dtype):
    """A CuPy array wrapped via populate_cuda_ptr reads back unchanged through cupy.asarray."""
    cupy = pytest.importorskip('cupy')
    expected = cupy.ones(10, dtype)

    region = ucp.BufferRegion()
    region.populate_cuda_ptr(expected)

    cupy.testing.assert_array_equal(cupy.asarray(region), expected)
def test_numba_empty():
    """Wrapping a zero-length numba device array gives an empty region with a 0 data pointer."""
    numba = pytest.importorskip("numba")
    import numba.cuda  # noqa

    empty_device_array = numba.cuda.device_array(0)
    region = ucp.BufferRegion()
    region.populate_cuda_ptr(empty_device_array)

    assert len(region) == 0
    data_ptr = region.__cuda_array_interface__["data"][0]
    assert data_ptr == 0
def test_set_read():
    """A host region populated from a memoryview exposes the same bytes back."""
    source = memoryview(b'hi')
    region = ucp.BufferRegion()
    region.populate_ptr(source)

    view = memoryview(region)
    assert view == source
    assert view.tobytes() == source.tobytes()

    # BufferRegion-specific attributes.
    assert region.is_cuda == 0
    assert region.shape[0] == 2
def test_numpy(dtype, data):
    """Populate a host BufferRegion from a NumPy array and read it back.

    ``data`` selects what is handed to ``populate_ptr``: ``True`` passes the
    array's ``.data`` memoryview, ``False`` passes the ndarray itself.
    """
    np = pytest.importorskip("numpy")
    arr = np.ones(10, dtype)
    buffer_region = ucp.BufferRegion()
    # Both branches previously passed ``arr.data``, so ``data=False`` tested
    # nothing new. NOTE(review): assumes populate_ptr accepts any object
    # exporting the buffer protocol (ndarray does) -- confirm.
    source = arr.data if data else arr
    buffer_region.populate_ptr(source)
    result = np.asarray(buffer_region)
    np.testing.assert_array_equal(result, arr)
def test_alloc_cuda_raises():
    """alloc_cuda raises ValueError whose message matches the global ``msg`` pattern."""
    region = ucp.BufferRegion()
    with pytest.raises(ValueError, match=msg):
        region.alloc_cuda(10)
def test_free_cuda_raises():
    """free_cuda raises ValueError whose message matches the global ``msg`` pattern."""
    region = ucp.BufferRegion()
    with pytest.raises(ValueError, match=msg):
        region.free_cuda()