Example #1
0
 def collective_fn(input_tensor, output_tensor, context):
     """Run a gloo reduce-scatter from ``input_tensor`` into ``output_tensor``.

     The element count of ``input_tensor`` is divided (floor division) by
     the size of ``self._gloo_context`` and that same count is requested
     for every rank. The reduce operation comes from the enclosing scope's
     ``reducescatter_options.reduceOp``.

     NOTE(review): if the element count is not a multiple of the world
     size, the per-rank counts sum to less than the total -- confirm the
     caller guarantees divisibility.
     """
     n_elements = gloo_util.get_tensor_n_elements(input_tensor)
     n_ranks = self._gloo_context.size

     send_ptr = gloo_util.get_tensor_ptr(input_tensor)
     recv_ptr = gloo_util.get_tensor_ptr(output_tensor)
     # One entry per rank, each asking for an equal share of the input.
     recv_counts = [n_elements // n_ranks for _rank in range(n_ranks)]
     gloo_dtype = gloo_util.get_gloo_tensor_dtype(output_tensor)
     gloo_op = gloo_util.get_gloo_reduce_op(reducescatter_options.reduceOp)

     pygloo.reduce_scatter(context, send_ptr, recv_ptr, n_elements,
                           recv_counts, gloo_dtype, gloo_op)
def test_reduce_scatter(rank, world_size, fileStore_path):
    """Run one pygloo reduce-scatter (SUM, float32) and print the buffers.

    Args:
        rank: Rank of this process within the participating processes.
        world_size: Number of participating processes.
        fileStore_path: Directory backing the rendezvous file store. Rank 0
            deletes it if present and recreates it before connecting.

    Rank ``r`` receives ``r + 1`` elements; the send buffer holds the values
    ``1..N`` where ``N = 1 + 2 + ... + world_size``.
    """
    if rank != 0:
        # Presumably gives rank 0 time to recreate the store directory
        # before the other ranks open it -- TODO confirm.
        time.sleep(0.5)
    else:
        if os.path.exists(fileStore_path):
            shutil.rmtree(fileStore_path)
        os.makedirs(fileStore_path)

    ctx = pygloo.rendezvous.Context(rank, world_size)

    # TCP transport device used for the pairwise connections.
    tcp_attr = pygloo.transport.tcp.attr("localhost")
    device = pygloo.transport.tcp.CreateDevice(tcp_attr)

    # File-backed rendezvous store, prefixed with the world size.
    file_store = pygloo.rendezvous.FileStore(fileStore_path)
    prefixed_store = pygloo.rendezvous.PrefixStore(str(world_size), file_store)

    ctx.connectFullMesh(prefixed_store, device)

    # Rank i is handed i + 1 elements, so the send buffer needs their sum.
    recv_counts = list(range(1, world_size + 1))
    sendbuf = np.array(range(1, sum(recv_counts) + 1), dtype=np.float32)
    recvbuf = np.zeros((rank + 1, ), dtype=np.float32)

    pygloo.reduce_scatter(
        ctx,
        sendbuf.ctypes.data,
        recvbuf.ctypes.data,
        sendbuf.size,
        recv_counts,
        pygloo.glooDataType_t.glooFloat32,
        pygloo.ReduceOp.SUM,
    )

    print(f"rank {rank} sends {sendbuf}, receives {recvbuf}")