Beispiel #1
0
def test_close_callback(server_close_callback):
    """Run ``_server`` and ``_client`` in subprocesses and require clean exits.

    ``server_close_callback`` is forwarded unchanged to both targets. The
    value the server process puts on the queue is handed to the client as
    its ``port`` argument. Each process gets 10 seconds to finish and must
    exit with code 0.

    Any process still alive when the test body ends (join timeout or a
    failed assertion) is terminated so it does not outlive the test.
    """
    endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0)

    queue = mp.Queue()
    server = mp.Process(
        target=_server,
        args=(queue, endpoint_error_handling, server_close_callback),
    )
    client = None
    server.start()
    try:
        # Blocks until the server process publishes a value on the queue.
        port = queue.get()
        client = mp.Process(
            target=_client,
            args=(port, endpoint_error_handling, server_close_callback),
        )
        client.start()
        client.join(timeout=10)
        server.join(timeout=10)
        # ``exitcode`` is None while a process is still running, so a
        # timed-out join fails these comparisons as well.
        assert client.exitcode == 0
        assert server.exitcode == 0
    finally:
        # Do not leak children when a join timed out or an assert failed.
        for proc in (client, server):
            if proc is not None and proc.is_alive():
                proc.terminate()
                proc.join(timeout=10)
Beispiel #2
0
def test_server_client(msg_size):
    """Run ``_echo_server`` and ``_echo_client`` in subprocesses.

    The value the server process puts on ``get_queue`` is handed to the
    client as its ``port`` argument. Once the client has exited, the string
    ``"Finished"`` is put on ``put_queue`` and the server is given 10
    seconds to exit.

    Both processes must exit with code 0. Comparing against 0 (rather than
    testing falsiness) matters because ``exitcode`` is ``None`` while a
    process is still running, so a hung process must not count as success.
    Any process still alive when the test body ends is terminated.
    """
    endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0)

    put_queue, get_queue = mp.Queue(), mp.Queue()
    server = mp.Process(
        target=_echo_server,
        args=(put_queue, get_queue, msg_size, endpoint_error_handling),
    )
    client = None
    server.start()
    try:
        # Blocks until the server process publishes a value on the queue.
        port = get_queue.get()
        client = mp.Process(
            target=_echo_client, args=(msg_size, port, endpoint_error_handling)
        )
        client.start()
        client.join(timeout=10)
        assert client.exitcode == 0
        put_queue.put("Finished")
        server.join(timeout=10)
        assert server.exitcode == 0
    finally:
        # A failed client assertion skips the "Finished" message above, so
        # the server may still be running; never leave children behind.
        for proc in (client, server):
            if proc is not None and proc.is_alive():
                proc.terminate()
                proc.join(timeout=10)
Beispiel #3
0
    def _listener_handler(conn_request, msg):
        """Create the endpoint for ``conn_request`` and post all receives.

        The endpoint is stored in the module-level global ``ep``. One
        wireup receive is posted first, followed by ``args.n_iter`` data
        receives. With ``args.enable_am`` the receives are active-message
        receives; otherwise they are tag receives (tag 0), each preceded by
        a call to ``op_started()``.

        ``msg`` is the receive buffer reused for every tag receive when
        ``args.reuse_alloc`` is set; otherwise a fresh buffer of
        ``args.n_bytes`` bytes is allocated per iteration.
        """
        global ep
        error_handling = ucx_api.get_ucx_version() >= (1, 10, 0)
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker, conn_request, endpoint_error_handling=error_handling
        )

        # Wireup before starting to transfer data
        if args.enable_am is True:
            ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
        else:
            wireup_buf = Array(bytearray(len(WireupMessage)))
            op_started()
            ucx_api.tag_recv_nb(
                worker,
                wireup_buf,
                wireup_buf.nbytes,
                tag=0,
                cb_func=_tag_recv_handle,
                cb_args=(ep, wireup_buf),
            )

        # Post one receive per benchmark iteration.
        for _ in range(args.n_iter):
            if args.enable_am is True:
                ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
                continue

            if not args.reuse_alloc:
                msg = Array(xp.zeros(args.n_bytes, dtype="u1"))

            op_started()
            ucx_api.tag_recv_nb(
                worker,
                msg,
                msg.nbytes,
                tag=0,
                cb_func=_tag_recv_handle,
                cb_args=(ep, msg),
            )
Beispiel #4
0
    os.environ["UCX_SEG_SIZE"] = "2M"
    ctx = ucx_api.UCXContext()
    config = ctx.get_config()
    assert config["SEG_SIZE"] == os.environ["UCX_SEG_SIZE"]


def test_init_options():
    """Options passed to ``UCXContext`` win over the ``UCX_SEG_SIZE`` env var.

    The environment variable is set only for the duration of the test and
    restored afterwards so the value does not leak into other tests.
    """
    previous = os.environ.get("UCX_SEG_SIZE")
    os.environ["UCX_SEG_SIZE"] = "2M"  # Should be ignored
    try:
        options = {"SEG_SIZE": "3M"}
        ctx = ucx_api.UCXContext(options)
        config = ctx.get_config()
        assert config["SEG_SIZE"] == options["SEG_SIZE"]
    finally:
        # Put the environment back exactly as it was found.
        if previous is None:
            os.environ.pop("UCX_SEG_SIZE", None)
        else:
            os.environ["UCX_SEG_SIZE"] = previous


@pytest.mark.skipif(
    ucx_api.get_ucx_version() >= (1, 12, 0),
    reason="Beginning with UCX >= 1.12, it's only possible to validate "
    "UCP options but not options from other modules such as UCT. "
    "See https://github.com/openucx/ucx/issues/7519.",
)
def test_init_unknown_option():
    """Creating a context with an unknown option key raises UCXConfigError."""
    with pytest.raises(UCXConfigError):
        ucx_api.UCXContext({"UNKNOWN_OPTION": "3M"})


def test_init_invalid_option():
    """Creating a context with an invalid SEG_SIZE raises UCXConfigError."""
    with pytest.raises(UCXConfigError):
        ucx_api.UCXContext({"SEG_SIZE": "invalid-size"})
Beispiel #5
0
def get_endpoint_error_handling_default():
    """Return True when the reported UCX version is at least 1.10.0."""
    minimum_version = (1, 10, 0)
    return ucx_api.get_ucx_version() >= minimum_version
Beispiel #6
0
def client(queue, port, server_address, args):
    """Run the client side of the roundtrip benchmark and print a report.

    Creates a UCX context, worker and endpoint for ``server_address`` and
    ``port``, performs one wireup exchange, then times ``args.n_iter``
    send/receive roundtrips of ``args.n_bytes`` bytes each and prints
    average, median and (optionally) per-iteration bandwidth.

    Parameters
    ----------
    queue
        Not used by this function.
    port
        Port passed to ``ucx_api.UCXEndpoint.create``.
    server_address
        Address passed to ``ucx_api.UCXEndpoint.create``.
    args
        Options object. Attributes read here: ``client_cpu_affinity``,
        ``server_cpu_affinity``, ``object_type``, ``client_dev``,
        ``server_dev``, ``rmm_init_pool_size``, ``enable_am``,
        ``n_bytes``, ``n_iter``, ``reuse_alloc``, ``delay_progress``,
        ``max_outstanding``, ``cuda_profile`` and ``no_detailed_report``.

    Notes
    -----
    ``args.delay_progress`` only affects the TAG path. The AM path always
    uses the blocking helpers and never reports completions through
    ``op_completed``, so the final drain loop is skipped for AM; waiting
    there for ``2 * args.n_iter`` completions would never terminate.
    """
    if args.client_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.client_cpu_affinity])

    import numpy as np

    # ``xp`` is the array module used to allocate the message buffers.
    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.client_dev)
    else:
        # Any other object type: CuPy arrays allocated through an RMM pool.
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.client_dev],
        )
        xp.cuda.runtime.setDevice(args.client_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(feature_flags=(
        ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
    ))
    worker = ucx_api.UCXWorker(ctx)
    register_am_allocators(args, worker)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        server_address,
        port,
        endpoint_error_handling=ucx_api.get_ucx_version() >= (1, 10, 0),
    )

    send_msg = xp.arange(args.n_bytes, dtype="u1")
    if args.reuse_alloc:
        recv_msg = xp.zeros(args.n_bytes, dtype="u1")

    # Wireup exchange before the timed loop.
    if args.enable_am:
        blocking_am_send(worker, ep, send_msg)
        blocking_am_recv(worker, ep)
    else:
        wireup_recv = bytearray(len(WireupMessage))
        blocking_send(worker, ep, WireupMessage)
        blocking_recv(worker, ep, wireup_recv)

    # Single-element lists so the nested helpers can mutate the counters.
    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def maybe_progress():
        # Progress the worker until fewer than ``max_outstanding``
        # operations are in flight.
        while outstanding[0] >= args.max_outstanding:
            worker.progress()

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    if args.cuda_profile:
        xp.cuda.profiler.start()

    times = []
    for _ in range(args.n_iter):
        start = clock()

        if args.enable_am:
            blocking_am_send(worker, ep, send_msg)
            blocking_am_recv(worker, ep)
        else:
            if not args.reuse_alloc:
                recv_msg = xp.zeros(args.n_bytes, dtype="u1")

            if args.delay_progress:
                maybe_progress()
                non_blocking_send(worker, ep, send_msg, op_started,
                                  op_completed)
                maybe_progress()
                non_blocking_recv(worker, ep, recv_msg, op_started,
                                  op_completed)
            else:
                blocking_send(worker, ep, send_msg)
                blocking_recv(worker, ep, recv_msg)

        stop = clock()
        times.append(stop - start)

    # Drain the non-blocking operations (one send and one receive per
    # iteration). Only the TAG path issues them; the AM path never calls
    # ``op_completed``, so waiting here in AM mode would spin forever.
    if args.delay_progress and not args.enable_am:
        while finished[0] != 2 * args.n_iter:
            worker.progress()

    if args.cuda_profile:
        xp.cuda.profiler.stop()

    assert len(times) == args.n_iter
    delay_progress_str = (f"True ({args.max_outstanding})"
                          if args.delay_progress is True else "False")
    print("Roundtrip benchmark")
    print("--------------------------")
    print(f"n_iter          | {args.n_iter}")
    print(f"n_bytes         | {format_bytes(args.n_bytes)}")
    print(f"object          | {args.object_type}")
    print(f"reuse alloc     | {args.reuse_alloc}")
    print(f"transfer API    | {'AM' if args.enable_am else 'TAG'}")
    print(f"delay progress  | {delay_progress_str}")
    print(f"UCX_TLS         | {ucp.get_config()['TLS']}")
    print(f"UCX_NET_DEVICES | {ucp.get_config()['NET_DEVICES']}")
    print("==========================")
    if args.object_type == "numpy":
        print("Device(s)       | CPU-only")
        s_aff = (args.server_cpu_affinity
                 if args.server_cpu_affinity >= 0 else "affinity not set")
        c_aff = (args.client_cpu_affinity
                 if args.client_cpu_affinity >= 0 else "affinity not set")
        print(f"Server CPU      | {s_aff}")
        print(f"Client CPU      | {c_aff}")
    else:
        print(f"Device(s)       | {args.server_dev}, {args.client_dev}")
    # Each roundtrip moves ``n_bytes`` in both directions, hence the 2x.
    avg = format_bytes(2 * args.n_iter * args.n_bytes / sum(times))
    med = format_bytes(2 * args.n_bytes / np.median(times))
    print(f"Average         | {avg}/s")
    print(f"Median          | {med}/s")
    if not args.no_detailed_report:
        print("--------------------------")
        print("Iterations")
        print("--------------------------")
        for i, t in enumerate(times):
            ts = format_bytes(2 * args.n_bytes / t)
            # Right-align the bandwidth in a 9-character column.
            ts = (" " * (9 - len(ts))) + ts
            print("%03d         |%s/s" % (i, ts))