Ejemplo n.º 1
0
def test_Array_contiguous_builtins(buffer):
    """Array contiguity flags must agree with ``memoryview`` for builtins.

    The buffer is checked as given and, when the stride-2 slice of it is
    non-empty, that slice is checked as well.
    """
    flag_names = ("c_contiguous", "f_contiguous", "contiguous")

    whole_view = memoryview(buffer)
    wrapped = Array(buffer)
    for name in flag_names:
        assert getattr(wrapped, name) == getattr(whole_view, name)

    strided_view = memoryview(buffer)[::2]
    # An empty slice has nothing to wrap; skip the strided comparison.
    if not strided_view:
        return
    wrapped_strided = Array(strided_view)
    for name in flag_names:
        assert getattr(wrapped_strided, name) == getattr(strided_view, name)
Ejemplo n.º 2
0
def test_feature_flags_mismatch(feature_flag):
    """Transfer APIs must reject a context that lacks their required feature.

    The context is created with ``feature_flag`` as its only feature. For
    every API family (tag, stream, active message) whose feature differs
    from ``feature_flag``, both the send and the receive call must raise
    ``ValueError`` naming the missing feature.
    """
    ctx = ucx_api.UCXContext(feature_flags=(feature_flag, ))
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create_from_worker_address(
        worker, worker.get_address(), endpoint_error_handling=False)
    msg = Array(bytearray(10))

    # (required feature, expected error pattern, calls needing that feature)
    expectations = [
        (
            ucx_api.Feature.TAG,
            "UCXContext must be created with `Feature.TAG`",
            [
                lambda: ucx_api.tag_send_nb(ep, msg, msg.nbytes, 0, None),
                lambda: ucx_api.tag_recv_nb(worker, msg, msg.nbytes, 0, None),
            ],
        ),
        (
            ucx_api.Feature.STREAM,
            "UCXContext must be created with `Feature.STREAM`",
            [
                lambda: ucx_api.stream_send_nb(ep, msg, msg.nbytes, None),
                lambda: ucx_api.stream_recv_nb(ep, msg, msg.nbytes, None),
            ],
        ),
        (
            ucx_api.Feature.AM,
            "UCXContext must be created with `Feature.AM`",
            [
                lambda: ucx_api.am_send_nbx(ep, msg, msg.nbytes, None),
                lambda: ucx_api.am_recv_nb(ep, None),
            ],
        ),
    ]

    for required, pattern, calls in expectations:
        # The enabled feature's own APIs are allowed; only check the others.
        if feature_flag == required:
            continue
        for call in calls:
            with pytest.raises(ValueError, match=pattern):
                call()
Ejemplo n.º 3
0
def test_Array_ndarray_contiguous(xp, shape, dtype, strides):
    """Array contiguity flags must mirror the wrapped ndarray's flags."""
    _, ndarr, _ = create_array(xp, shape, dtype, strides)
    wrapped = Array(ndarr)
    flags = ndarr.flags

    assert wrapped.c_contiguous == flags.c_contiguous
    assert wrapped.f_contiguous == flags.f_contiguous
    # "contiguous" means contiguous in either C or Fortran order.
    assert wrapped.contiguous == (flags.c_contiguous or flags.f_contiguous)
Ejemplo n.º 4
0
 def _recv_handle(recv_obj, exception, ep):
     """Echo a received active message back over ``ep``.

     The wrapped message is also handed to the send callback as an argument,
     presumably so it stays referenced until the send completes.
     """
     assert exception is None
     reply = Array(recv_obj)
     ucx_api.am_send_nbx(
         ep, reply, reply.nbytes, cb_func=_send_handle, cb_args=(reply, )
     )
Ejemplo n.º 5
0
 def _listener_handler(conn_request):
     """Accept ``conn_request`` and post a tag receive on the new endpoint.

     The endpoint is stored in the module-level ``ep`` (declared ``global``).
     A ``msg_size``-byte buffer is posted with tag 0 and ``_recv_handle`` as
     the completion callback.
     """
     global ep
     ep = ucx_api.UCXEndpoint.create_from_conn_request(
         worker, conn_request, endpoint_error_handling=endpoint_error_handling,
     )
     recv_buf = Array(bytearray(msg_size))
     ucx_api.tag_recv_nb(
         worker,
         recv_buf,
         recv_buf.nbytes,
         tag=0,
         cb_func=_recv_handle,
         cb_args=(ep, recv_buf),
     )
Ejemplo n.º 6
0
def blocking_am_send(worker, ep, msg):
    """Send ``msg`` as an active message over ``ep`` and wait for completion.

    ``msg`` is wrapped in ``Array``. If the send is not completed immediately
    (a request is returned), ``worker`` is progressed until ``done[0]``
    becomes truthy, presumably set by ``blocking_handler``.
    """
    payload = Array(msg)
    done = [False]
    request = ucx_api.am_send_nbx(
        ep, payload, payload.nbytes, cb_func=blocking_handler, cb_args=(done, )
    )
    if request is None:
        # Completed immediately; nothing to wait for.
        return
    while not done[0]:
        worker.progress()
Ejemplo n.º 7
0
    def _listener_handler(conn_request, msg):
        """Accept a connection, post its wireup receive, then post echo receives.

        The endpoint is stored in the module-level ``ep``. With active
        messages enabled, one ``am_recv_nb`` is posted for the wireup and one
        per iteration. Otherwise tag receives are posted and counted through
        ``op_started``. ``msg`` is reused for every receive unless
        ``args.reuse_alloc`` is false, in which case a fresh zeroed buffer is
        allocated per iteration.
        """
        global ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )
        use_am = args.enable_am is True

        # Wireup before starting to transfer data.
        if use_am:
            ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
        else:
            wireup_buf = Array(bytearray(len(WireupMessage)))
            op_started()
            ucx_api.tag_recv_nb(
                worker,
                wireup_buf,
                wireup_buf.nbytes,
                tag=0,
                cb_func=_tag_recv_handle,
                cb_args=(ep, wireup_buf),
            )

        for _ in range(args.n_iter + args.n_warmup_iter):
            if use_am:
                ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
                continue

            recv_buf = (
                msg
                if args.reuse_alloc
                else Array(xp.zeros(args.n_bytes, dtype="u1"))
            )
            op_started()
            ucx_api.tag_recv_nb(
                worker,
                recv_buf,
                recv_buf.nbytes,
                tag=0,
                cb_func=_tag_recv_handle,
                cb_args=(ep, recv_buf),
            )
Ejemplo n.º 8
0
def non_blocking_send(worker, ep, msg, started_cb, completed_cb, tag=0):
    """Post a tag send of ``msg`` over ``ep`` without waiting for it.

    ``started_cb`` is called before the send is posted. If the send completes
    immediately (no request is returned), ``completed_cb`` is called right
    away. Otherwise it is passed as the callback argument of
    ``non_blocking_handler``.

    Returns the request returned by ``tag_send_nb`` (``None`` on immediate
    completion).
    """
    payload = Array(msg)
    started_cb()
    request = ucx_api.tag_send_nb(
        ep,
        payload,
        payload.nbytes,
        tag=tag,
        cb_func=non_blocking_handler,
        cb_args=(completed_cb, ),
    )
    completed_now = request is None
    if completed_now:
        completed_cb()
    return request
Ejemplo n.º 9
0
def blocking_recv(worker, ep, msg, tag=0):
    """Receive a tagged message into ``msg`` and wait until it arrives.

    The receive is posted on ``worker`` restricted to ``ep``. If a request is
    returned, ``worker`` is progressed until ``done[0]`` becomes truthy,
    presumably set by ``blocking_handler``.
    """
    recv_buf = Array(msg)
    done = [False]
    request = ucx_api.tag_recv_nb(
        worker,
        recv_buf,
        recv_buf.nbytes,
        tag=tag,
        cb_func=blocking_handler,
        cb_args=(done, ),
        ep=ep,
    )
    if request is None:
        # Completed immediately; nothing to wait for.
        return
    while not done[0]:
        worker.progress()
Ejemplo n.º 10
0
def _client_cancel(queue, transfer_api):
    """Connect to the server and post one receive that is expected to be canceled.

    The server closes without sending anything. Once the endpoint is no
    longer alive, the in-flight receive is canceled explicitly. The test
    checks that exactly one message was canceled and that the callback got a
    ``UCXCanceled`` whose message names the receive operation.
    """
    use_am = transfer_api == "am"
    feature = ucx_api.Feature.AM if use_am else ucx_api.Feature.TAG
    ctx = ucx_api.UCXContext(feature_flags=(feature, ))
    worker = ucx_api.UCXWorker(ctx)
    port = queue.get()
    ep = ucx_api.UCXEndpoint.create(
        worker,
        get_address(),
        port,
        endpoint_error_handling=True,
    )

    result = [None]

    if use_am:
        ucx_api.am_recv_nb(ep, cb_func=_handler, cb_args=(result, ))
        expected_pattern = ".*am_recv.*"
    else:
        # Keep the buffer referenced for the whole function so the posted
        # receive has a live destination.
        recv_buf = Array(bytearray(1))
        ucx_api.tag_recv_nb(
            worker,
            recv_buf,
            recv_buf.nbytes,
            tag=0,
            cb_func=_handler,
            cb_args=(result, ),
            ep=ep,
        )
        expected_pattern = ".*tag_recv_nb.*"

    # Progress until the server-side close is observed.
    while ep.is_alive():
        worker.progress()

    num_canceled = worker.cancel_inflight_messages()

    # Progress until the receive callback has reported its outcome.
    while result[0] is None:
        worker.progress()

    assert num_canceled == 1
    assert isinstance(result[0], UCXCanceled)
    assert re.match(expected_pattern, result[0].args[0])
Ejemplo n.º 11
0
def test_Array_nbytes_builtins(buffer):
    """``Array.nbytes`` matches ``memoryview.nbytes`` for builtin buffers."""
    expected = Array(buffer), memoryview(buffer)
    wrapped, view = expected
    assert wrapped.nbytes == view.nbytes
Ejemplo n.º 12
0
def test_Array_strides_builtins(buffer):
    """``Array.strides`` matches ``memoryview.strides`` for builtin buffers."""
    wrapped = Array(buffer)
    view = memoryview(buffer)
    expected_strides = view.strides
    assert wrapped.strides == expected_strides
Ejemplo n.º 13
0
def test_Array_shape_builtins(buffer):
    """``Array.shape`` matches ``memoryview.shape`` for builtin buffers."""
    wrapped = Array(buffer)
    view = memoryview(buffer)
    expected_shape = view.shape
    assert wrapped.shape == expected_shape
Ejemplo n.º 14
0
def test_Array_ndim_builtins(buffer):
    """``Array.ndim`` matches ``memoryview.ndim`` for builtin buffers."""
    wrapped = Array(buffer)
    view = memoryview(buffer)
    expected_ndim = view.ndim
    assert wrapped.ndim == expected_ndim
Ejemplo n.º 15
0
def test_Array_itemsize_builtins(buffer):
    """``Array.itemsize`` matches ``memoryview.itemsize`` for builtin buffers."""
    wrapped = Array(buffer)
    view = memoryview(buffer)
    expected_itemsize = view.itemsize
    assert wrapped.itemsize == expected_itemsize
Ejemplo n.º 16
0
def test_Array_obj_builtins(buffer):
    """``Array.obj`` is the very object ``memoryview.obj`` exposes."""
    wrapped = Array(buffer)
    view = memoryview(buffer)
    underlying = view.obj
    assert wrapped.obj is underlying
Ejemplo n.º 17
0
def test_Array_readonly_builtins(buffer):
    """``Array.readonly`` matches ``memoryview.readonly`` for builtin buffers."""
    wrapped = Array(buffer)
    view = memoryview(buffer)
    expected_readonly = view.readonly
    assert wrapped.readonly == expected_readonly
Ejemplo n.º 18
0
def test_Array_ptr_builtins(buffer):
    """Wrapping a builtin buffer yields a non-null data pointer."""
    data_ptr = Array(buffer).ptr
    assert data_ptr != 0
Ejemplo n.º 19
0
def server(queue, args):
    """Run the echo-server side of the benchmark.

    Creates a UCX context and worker and starts a listener on ``args.port``
    (``0`` when falsy, presumably letting the OS pick a port). The bound port
    is published through ``queue``. Every message received on the accepted
    endpoint is sent back, via the tag API or active messages depending on
    ``args.enable_am``. The worker is progressed until the wireup message
    plus ``args.n_iter + args.n_warmup_iter`` echoes have completed.

    Parameters
    ----------
    queue
        Queue-like object; the listener port is ``put`` on it.
    args
        Parsed benchmark options. Read here: server_cpu_affinity,
        object_type, server_dev, rmm_init_pool_size, enable_am, port,
        n_bytes, n_iter, n_warmup_iter, reuse_alloc, delay_progress,
        max_outstanding.
    """
    if args.server_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.server_cpu_affinity])

    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.server_dev)
    else:
        # Any other object type: CuPy arrays allocated from an RMM pool.
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.server_dev],
        )
        xp.cuda.runtime.setDevice(args.server_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(
        feature_flags=(
            ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
        )
    )
    worker = ucx_api.UCXWorker(ctx)

    register_am_allocators(args, worker)

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early. `_listener_handler` assigns it via `nonlocal`.
    ep = None

    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def op_started():
        # Count one more operation as in flight.
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        # Move one operation from in-flight to finished.
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    def _send_handle(request, exception, msg):
        # Notice, we pass `msg` to the handler in order to make sure
        # it doesn't go out of scope prematurely.
        assert exception is None
        op_completed()

    def _tag_recv_handle(request, exception, ep, msg):
        # Echo a received tag message back to the sender.
        assert exception is None
        req = ucx_api.tag_send_nb(
            ep, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
        )
        # No request means the send completed immediately and `_send_handle`
        # will not run, so account for it here.
        if req is None:
            op_completed()

    def _am_recv_handle(recv_obj, exception, ep):
        # Echo a received active message back to the sender.
        assert exception is None
        msg = Array(recv_obj)
        req = ucx_api.am_send_nbx(
            ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,)
        )
        # Same immediate-completion case as in `_tag_recv_handle`. Without
        # this, `finished` would never reach its target and the progress
        # loop below would spin forever.
        if req is None:
            op_completed()

    def _listener_handler(conn_request, msg):
        # Bind the enclosing `ep` (not a module global) so the keep-alive
        # reference declared above is the one actually holding the endpoint.
        nonlocal ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

        # Wireup before starting to transfer data
        if args.enable_am is True:
            ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
        else:
            wireup = Array(bytearray(len(WireupMessage)))
            op_started()
            ucx_api.tag_recv_nb(
                worker,
                wireup,
                wireup.nbytes,
                tag=0,
                cb_func=_tag_recv_handle,
                cb_args=(ep, wireup),
            )

        for _ in range(args.n_iter + args.n_warmup_iter):
            if args.enable_am is True:
                ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
            else:
                # Reuse the preallocated buffer or allocate one per receive.
                if not args.reuse_alloc:
                    msg = Array(xp.zeros(args.n_bytes, dtype="u1"))

                op_started()
                ucx_api.tag_recv_nb(
                    worker,
                    msg,
                    msg.nbytes,
                    tag=0,
                    cb_func=_tag_recv_handle,
                    cb_args=(ep, msg),
                )

    if not args.enable_am and args.reuse_alloc:
        msg = Array(xp.zeros(args.n_bytes, dtype="u1"))
    else:
        msg = None

    listener = ucx_api.UCXListener(
        worker=worker, port=args.port or 0, cb_func=_listener_handler, cb_args=(msg,)
    )
    queue.put(listener.port)

    # Wait for activity to begin. NOTE(review): the AM path never calls
    # op_started(), so there `outstanding` only decreases (becomes negative)
    # and this loop exits after the first completed echo.
    while outstanding[0] == 0:
        worker.progress()

    # +1 to account for wireup message
    if args.delay_progress:
        while finished[0] < args.n_iter + args.n_warmup_iter + 1 and (
            outstanding[0] >= args.max_outstanding
            or finished[0] + args.max_outstanding
            >= args.n_iter + args.n_warmup_iter + 1
        ):
            worker.progress()
    else:
        while finished[0] != args.n_iter + args.n_warmup_iter + 1:
            worker.progress()
Ejemplo n.º 20
0
def test_Array_ndarray_strides(xp, shape, dtype, strides):
    """``Array.strides`` matches the strides of the wrapped ndarray."""
    _, ndarr, _ = create_array(xp, shape, dtype, strides)
    wrapped = Array(ndarr)
    expected_strides = ndarr.strides
    assert wrapped.strides == expected_strides
Ejemplo n.º 21
0
def test_Array_ndarray_is_cuda(xp, shape, dtype, strides):
    """``Array.cuda`` is true exactly when the array module is CuPy."""
    array_module, ndarr, _ = create_array(xp, shape, dtype, strides)
    assert Array(ndarr).cuda == (array_module.__name__ == "cupy")
Ejemplo n.º 22
0
def test_Array_ndarray_ptr(xp, shape, dtype, strides):
    """``Array.ptr`` equals the data pointer in the array's interface dict."""
    _, ndarr, interface = create_array(xp, shape, dtype, strides)
    assert Array(ndarr).ptr == interface["data"][0]