Exemplo n.º 1
0
def _test_peer_communication_tag(queue, rank, msg_size):
    """Run one participant of a ring-style tag send/receive test.

    The worker address is published on ``queue`` together with ``rank``;
    two ``(rank, address)`` entries are then read back and treated as the
    right and left neighbours, in that order.  Rank 0 originates a random
    message of ``msg_size`` bytes, sends it to the right neighbour, waits
    for it to arrive from the left neighbour and asserts that it is
    unchanged.  Every other rank receives from the left and forwards the
    received bytes to the right.

    Parameters
    ----------
    queue
        Object exposing ``put``/``get`` used to exchange worker addresses.
    rank
        This participant's rank; also passed as the last argument of its
        receive calls (presumably the tag -- confirm against
        ``blocking_recv``).
    msg_size
        Size in bytes of the message sent around the ring.
    """
    context = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(context)

    def _connect(address):
        # Both neighbours are reached the same way, with error handling on.
        return ucx_api.UCXEndpoint.create_from_worker_address(
            worker,
            address,
            endpoint_error_handling=True,
        )

    queue.put((rank, worker.get_address()))
    right_rank, right_address = queue.get()
    _, left_address = queue.get()

    right_ep = _connect(right_address)
    left_ep = _connect(left_address)

    buffer = bytearray(msg_size)
    if rank != 0:
        # Intermediate rank: relay whatever arrives from the left.
        blocking_recv(worker, left_ep, buffer, rank)
        blocking_send(worker, right_ep, buffer, right_rank)
        return

    # Rank 0 starts the ring and checks the message after a full lap.
    payload = os.urandom(msg_size)
    blocking_send(worker, right_ep, payload, right_rank)
    blocking_recv(worker, left_ep, buffer, rank)
    assert payload == buffer
Exemplo n.º 2
0
def _server_probe(queue, transfer_api):
    """Server that probes and receives a message after the client disconnected.

    Note that since it is illegal to call progress() in callback functions,
    we keep a reference to the endpoint after the listener callback has
    terminated, this way we can progress even after Python blocking calls.

    Parameters
    ----------
    queue
        Object exposing ``put``; receives the listener port first and the
        string ``"wireup completed"`` once the wireup message has arrived.
    transfer_api
        ``"am"`` selects the active-message code path; any other value
        selects the tag code path.

    Raises
    ------
    AssertionError
        If the wireup or data message differs from ``WireupMessage`` /
        ``DataMessage``.
    """
    feature = ucx_api.Feature.AM if transfer_api == "am" else ucx_api.Feature.TAG
    ctx = ucx_api.UCXContext(feature_flags=(feature,))
    worker = ucx_api.UCXWorker(ctx)

    # Single-slot holder written by the listener callback.  It deliberately
    # has its own name: the callback closes over this variable, so rebinding
    # it to the endpoint would make a later callback invocation attempt
    # item assignment on the endpoint object instead of on this list.
    ep_holder = [None]

    def _listener_handler(conn_request):
        ep_holder[0] = ucx_api.UCXEndpoint.create_from_conn_request(
            worker,
            conn_request,
            endpoint_error_handling=True,
        )

    listener = ucx_api.UCXListener(worker=worker,
                                   port=0,
                                   cb_func=_listener_handler)
    queue.put(listener.port)

    # Progress the worker until the listener callback has stored an endpoint
    while ep_holder[0] is None:
        worker.progress()

    ep = ep_holder[0]

    # Ensure wireup and inform client before it can disconnect
    if transfer_api == "am":
        wireup = blocking_am_recv(worker, ep)
    else:
        wireup = bytearray(len(WireupMessage))
        blocking_recv(worker, ep, wireup)
    queue.put("wireup completed")

    # Ensure client has disconnected -- endpoint is not alive anymore
    while ep.is_alive() is True:
        worker.progress()

    # Probe/receive message even after the remote endpoint has disconnected
    if transfer_api == "am":
        while ep.am_probe() is False:
            worker.progress()
        received = blocking_am_recv(worker, ep)
    else:
        while worker.tag_probe(0) is False:
            worker.progress()
        received = bytearray(len(DataMessage))
        blocking_recv(worker, ep, received)

    assert wireup == WireupMessage
    assert received == DataMessage
Exemplo n.º 3
0
def _echo_client(msg_size, port, endpoint_error_handling):
    """Send random bytes to a local server and check they are echoed back.

    Parameters
    ----------
    msg_size
        Number of random bytes to send and to expect in return.
    port
        Port on ``"localhost"`` to connect to.
    endpoint_error_handling
        Forwarded unchanged to ``ucx_api.UCXEndpoint.create``.

    Raises
    ------
    AssertionError
        If the received bytes differ from the bytes that were sent.
    """
    context = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(context)
    server_ep = ucx_api.UCXEndpoint.create(
        worker,
        "localhost",
        port,
        endpoint_error_handling=endpoint_error_handling,
    )

    payload = os.urandom(msg_size)
    echoed = bytearray(msg_size)

    blocking_send(worker, server_ep, payload)
    blocking_recv(worker, server_ep, echoed)
    assert payload == echoed
Exemplo n.º 4
0
def _test_peer_communication_rma(queue, rank, msg_size):
    """Run one participant of a ring-style RMA put/get test.

    The participant allocates ``msg_size`` bytes through the context, fills
    them (via a put to itself) with ``msg_size`` repetitions of ``rank``,
    and publishes ``(rank, address, packed rkey, base address)`` on
    ``queue``.  It then reads two such entries back -- treated as the right
    and left neighbours, in that order -- fetches each neighbour's memory
    with a get, and asserts that each holds ``msg_size`` repetitions of
    that neighbour's rank.  A final tag send/receive around the ring acts
    as a barrier.

    Parameters
    ----------
    queue
        Object exposing ``put``/``get`` used to exchange peer information.
    rank
        This participant's rank; also the byte value written to its memory.
    msg_size
        Size in bytes of the memory region that is put and fetched.
    """
    features = (ucx_api.Feature.RMA, ucx_api.Feature.TAG)
    ctx = ucx_api.UCXContext(feature_flags=features)
    worker = ucx_api.UCXWorker(ctx)

    def _pattern(value):
        # msg_size bytes, each equal to ``value``.
        return bytes(repeat(value, msg_size))

    own_address = worker.get_address()
    mem_handle = ctx.alloc(msg_size)
    own_base = mem_handle.address
    own_prkey = mem_handle.pack_rkey()

    own_ep, own_mem = _rma_setup(
        worker, own_address, own_prkey, own_base, msg_size
    )
    # Kept in a local so the source buffer stays referenced after the put.
    own_pattern = _pattern(rank)
    if not own_mem.put_nbi(own_pattern):
        blocking_flush(own_ep)

    queue.put((rank, own_address, own_prkey, own_base))
    right_rank, right_address, right_prkey, right_base = queue.get()
    left_rank, left_address, left_prkey, left_base = queue.get()

    right_ep, right_mem = _rma_setup(
        worker, right_address, right_prkey, right_base, msg_size
    )
    fetched_right = bytearray(msg_size)
    right_mem.get_nbi(fetched_right)

    left_ep, left_mem = _rma_setup(
        worker, left_address, left_prkey, left_base, msg_size
    )
    fetched_left = bytearray(msg_size)
    left_mem.get_nbi(fetched_left)

    # Both gets were issued without waiting; flush before inspecting them.
    blocking_flush(worker)
    assert fetched_left == _pattern(left_rank)
    assert fetched_right == _pattern(right_rank)

    # We use the blocking tag send/recv as a barrier implementation
    token = bytearray(8)
    if rank != 0:
        blocking_recv(worker, left_ep, token, rank)
        blocking_send(worker, right_ep, token, right_rank)
    else:
        starter = os.urandom(8)
        blocking_send(worker, right_ep, starter, right_rank)
        blocking_recv(worker, left_ep, token, rank)
Exemplo n.º 5
0
def client(queue, port, server_address, args):
    """Run the client side of the roundtrip benchmark and print a report.

    The client optionally pins itself to a CPU, selects the array module
    according to ``args.object_type`` (``"numpy"``, ``"cupy"`` or anything
    else for cupy backed by an RMM pool), connects to the server, performs
    a wireup exchange and then runs ``args.n_iter + args.n_warmup_iter``
    send/receive roundtrips.  Only iterations after the warmup are timed
    into the report, which is printed to stdout.

    Parameters
    ----------
    queue
        Not used by this function; kept for the caller's signature.
    port
        Port of the server to connect to.
    server_address
        Address of the server to connect to.
    args
        Parsed benchmark options.  Attributes read here:
        ``client_cpu_affinity``, ``server_cpu_affinity``, ``object_type``,
        ``client_dev``, ``server_dev``, ``rmm_init_pool_size``,
        ``enable_am``, ``n_bytes``, ``reuse_alloc``, ``max_outstanding``,
        ``delay_progress``, ``cuda_profile``, ``n_iter``,
        ``n_warmup_iter`` and ``no_detailed_report``.

    Raises
    ------
    AssertionError
        If the number of recorded timings differs from ``args.n_iter``.
    """
    if args.client_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.client_cpu_affinity])

    import numpy as np

    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.client_dev)
    else:
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.client_dev],
        )
        xp.cuda.runtime.setDevice(args.client_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(
        feature_flags=(
            ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
        )
    )
    worker = ucx_api.UCXWorker(ctx)
    register_am_allocators(args, worker)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        server_address,
        port,
        endpoint_error_handling=True,
    )

    send_msg = xp.arange(args.n_bytes, dtype="u1")
    if args.reuse_alloc:
        recv_msg = xp.zeros(args.n_bytes, dtype="u1")

    # Wireup exchange before any timing starts
    if args.enable_am:
        blocking_am_send(worker, ep, send_msg)
        blocking_am_recv(worker, ep)
    else:
        wireup_recv = bytearray(len(WireupMessage))
        blocking_send(worker, ep, WireupMessage)
        blocking_recv(worker, ep, wireup_recv)

    # Counters shared with the non-blocking completion callbacks; one-element
    # lists so the nested functions can mutate them.
    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def maybe_progress():
        # Progress only once the number of in-flight operations reaches the
        # configured limit.
        while outstanding[0] >= args.max_outstanding:
            worker.progress()

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    if args.cuda_profile:
        xp.cuda.profiler.start()

    total_iterations = args.n_iter + args.n_warmup_iter
    times = []
    for i in range(total_iterations):
        start = clock()

        if args.enable_am:
            blocking_am_send(worker, ep, send_msg)
            blocking_am_recv(worker, ep)
        else:
            if not args.reuse_alloc:
                recv_msg = xp.zeros(args.n_bytes, dtype="u1")

            if args.delay_progress:
                maybe_progress()
                non_blocking_send(worker, ep, send_msg, op_started, op_completed)
                maybe_progress()
                non_blocking_recv(worker, ep, recv_msg, op_started, op_completed)
            else:
                blocking_send(worker, ep, send_msg)
                blocking_recv(worker, ep, recv_msg)

        stop = clock()
        if i >= args.n_warmup_iter:
            times.append(stop - start)

    # Drain the operations left in flight by the delayed-progress tag path.
    # AM transfers above always use the blocking helpers, so no completion
    # callback ever fires for them and `finished` stays at zero; waiting in
    # AM mode would spin forever.
    if args.delay_progress and not args.enable_am:
        expected_completions = 2 * total_iterations
        while finished[0] < expected_completions:
            worker.progress()

    if args.cuda_profile:
        xp.cuda.profiler.stop()

    assert len(times) == args.n_iter
    # Each iteration moves n_bytes twice (send + receive), hence the factor 2
    bw_avg = format_bytes(2 * args.n_iter * args.n_bytes / sum(times))
    bw_med = format_bytes(2 * args.n_bytes / np.median(times))
    lat_avg = int(sum(times) * 1e9 / (2 * args.n_iter))
    lat_med = int(np.median(times) * 1e9 / 2)

    delay_progress_str = (
        f"True ({args.max_outstanding})" if args.delay_progress is True else "False"
    )
    ucx_config = ucp.get_config()

    print("Roundtrip benchmark")
    print_separator(separator="=")
    print_key_value(key="Iterations", value=f"{args.n_iter}")
    print_key_value(key="Bytes", value=f"{format_bytes(args.n_bytes)}")
    print_key_value(key="Object type", value=f"{args.object_type}")
    print_key_value(key="Reuse allocation", value=f"{args.reuse_alloc}")
    print_key_value(key="Transfer API", value=f"{'AM' if args.enable_am else 'TAG'}")
    print_key_value(key="Delay progress", value=f"{delay_progress_str}")
    print_key_value(key="UCX_TLS", value=f"{ucx_config['TLS']}")
    print_key_value(key="UCX_NET_DEVICES", value=f"{ucx_config['NET_DEVICES']}")
    print_separator(separator="=")
    if args.object_type == "numpy":
        print_key_value(key="Device(s)", value="CPU-only")
        s_aff = (
            args.server_cpu_affinity
            if args.server_cpu_affinity >= 0
            else "affinity not set"
        )
        c_aff = (
            args.client_cpu_affinity
            if args.client_cpu_affinity >= 0
            else "affinity not set"
        )
        print_key_value(key="Server CPU", value=f"{s_aff}")
        print_key_value(key="Client CPU", value=f"{c_aff}")
    else:
        print_key_value(key="Device(s)", value=f"{args.server_dev}, {args.client_dev}")
    print_separator(separator="=")
    print_key_value("Bandwidth (average)", value=f"{bw_avg}/s")
    print_key_value("Bandwidth (median)", value=f"{bw_med}/s")
    print_key_value("Latency (average)", value=f"{lat_avg} ns")
    print_key_value("Latency (median)", value=f"{lat_med} ns")
    if not args.no_detailed_report:
        print_separator(separator="=")
        print_key_value(key="Iterations", value="Bandwidth, Latency")
        print_separator(separator="-")
        for i, t in enumerate(times):
            ts = format_bytes(2 * args.n_bytes / t)
            lat = int(t * 1e9 / 2)
            print_key_value(key=i, value=f"{ts}/s, {lat}ns")
Exemplo n.º 6
0
def client(queue, port, server_address, args):
    """Run the client side of the roundtrip benchmark and print a report.

    The client optionally pins itself to a CPU, selects the array module
    according to ``args.object_type`` (``"numpy"``, ``"cupy"`` or anything
    else for cupy backed by an RMM pool), connects to the server, performs
    a wireup exchange and then times ``args.n_iter`` send/receive
    roundtrips.  The report is printed to stdout.

    Parameters
    ----------
    queue
        Not used by this function; kept for the caller's signature.
    port
        Port of the server to connect to.
    server_address
        Address of the server to connect to.
    args
        Parsed benchmark options.  Attributes read here:
        ``client_cpu_affinity``, ``server_cpu_affinity``, ``object_type``,
        ``client_dev``, ``server_dev``, ``rmm_init_pool_size``,
        ``enable_am``, ``n_bytes``, ``reuse_alloc``, ``max_outstanding``,
        ``delay_progress``, ``cuda_profile``, ``n_iter`` and
        ``no_detailed_report``.

    Raises
    ------
    AssertionError
        If the number of recorded timings differs from ``args.n_iter``.
    """
    if args.client_cpu_affinity >= 0:
        os.sched_setaffinity(0, [args.client_cpu_affinity])

    import numpy as np

    if args.object_type == "numpy":
        import numpy as xp
    elif args.object_type == "cupy":
        import cupy as xp

        xp.cuda.runtime.setDevice(args.client_dev)
    else:
        import cupy as xp

        import rmm

        rmm.reinitialize(
            pool_allocator=True,
            managed_memory=False,
            initial_pool_size=args.rmm_init_pool_size,
            devices=[args.client_dev],
        )
        xp.cuda.runtime.setDevice(args.client_dev)
        xp.cuda.set_allocator(rmm.rmm_cupy_allocator)

    ctx = ucx_api.UCXContext(feature_flags=(
        ucx_api.Feature.AM if args.enable_am is True else ucx_api.Feature.TAG,
    ))
    worker = ucx_api.UCXWorker(ctx)
    register_am_allocators(args, worker)
    ep = ucx_api.UCXEndpoint.create(
        worker,
        server_address,
        port,
        # Error handling is only requested when the UCX version is >= 1.10.0
        endpoint_error_handling=ucx_api.get_ucx_version() >= (1, 10, 0),
    )

    send_msg = xp.arange(args.n_bytes, dtype="u1")
    if args.reuse_alloc:
        recv_msg = xp.zeros(args.n_bytes, dtype="u1")

    # Wireup exchange before any timing starts
    if args.enable_am:
        blocking_am_send(worker, ep, send_msg)
        blocking_am_recv(worker, ep)
    else:
        wireup_recv = bytearray(len(WireupMessage))
        blocking_send(worker, ep, WireupMessage)
        blocking_recv(worker, ep, wireup_recv)

    # Counters shared with the non-blocking completion callbacks; one-element
    # lists so the nested functions can mutate them.
    op_lock = Lock()
    finished = [0]
    outstanding = [0]

    def maybe_progress():
        # Progress only once the number of in-flight operations reaches the
        # configured limit.
        while outstanding[0] >= args.max_outstanding:
            worker.progress()

    def op_started():
        with op_lock:
            outstanding[0] += 1

    def op_completed():
        with op_lock:
            outstanding[0] -= 1
            finished[0] += 1

    if args.cuda_profile:
        xp.cuda.profiler.start()

    times = []
    for i in range(args.n_iter):
        start = clock()

        if args.enable_am:
            blocking_am_send(worker, ep, send_msg)
            blocking_am_recv(worker, ep)
        else:
            if not args.reuse_alloc:
                recv_msg = xp.zeros(args.n_bytes, dtype="u1")

            if args.delay_progress:
                maybe_progress()
                non_blocking_send(worker, ep, send_msg, op_started,
                                  op_completed)
                maybe_progress()
                non_blocking_recv(worker, ep, recv_msg, op_started,
                                  op_completed)
            else:
                blocking_send(worker, ep, send_msg)
                blocking_recv(worker, ep, recv_msg)

        stop = clock()
        times.append(stop - start)

    # Drain the operations left in flight by the delayed-progress tag path.
    # AM transfers above always use the blocking helpers, so no completion
    # callback ever fires for them and `finished` stays at zero; waiting in
    # AM mode would spin forever.
    if args.delay_progress and not args.enable_am:
        expected_completions = 2 * args.n_iter
        while finished[0] < expected_completions:
            worker.progress()

    if args.cuda_profile:
        xp.cuda.profiler.stop()

    assert len(times) == args.n_iter
    delay_progress_str = (f"True ({args.max_outstanding})"
                          if args.delay_progress is True else "False")
    ucx_config = ucp.get_config()
    print("Roundtrip benchmark")
    print("--------------------------")
    print(f"n_iter          | {args.n_iter}")
    print(f"n_bytes         | {format_bytes(args.n_bytes)}")
    print(f"object          | {args.object_type}")
    print(f"reuse alloc     | {args.reuse_alloc}")
    print(f"transfer API    | {'AM' if args.enable_am else 'TAG'}")
    print(f"delay progress  | {delay_progress_str}")
    print(f"UCX_TLS         | {ucx_config['TLS']}")
    print(f"UCX_NET_DEVICES | {ucx_config['NET_DEVICES']}")
    print("==========================")
    if args.object_type == "numpy":
        print("Device(s)       | CPU-only")
        s_aff = (args.server_cpu_affinity
                 if args.server_cpu_affinity >= 0 else "affinity not set")
        c_aff = (args.client_cpu_affinity
                 if args.client_cpu_affinity >= 0 else "affinity not set")
        print(f"Server CPU      | {s_aff}")
        print(f"Client CPU      | {c_aff}")
    else:
        print(f"Device(s)       | {args.server_dev}, {args.client_dev}")
    # Each iteration moves n_bytes twice (send + receive), hence the factor 2
    avg = format_bytes(2 * args.n_iter * args.n_bytes / sum(times))
    med = format_bytes(2 * args.n_bytes / np.median(times))
    print(f"Average         | {avg}/s")
    print(f"Median          | {med}/s")
    if not args.no_detailed_report:
        print("--------------------------")
        print("Iterations")
        print("--------------------------")
        for i, t in enumerate(times):
            # Right-align the per-iteration bandwidth in a 9-character column
            ts = format_bytes(2 * args.n_bytes / t).rjust(9)
            print("%03d         |%s/s" % (i, ts))