def collective_fn(input_tensor, output_tensor, context):
    # Nested closure: self._gloo_context and reducescatter_options come from
    # the enclosing scope.
    size = gloo_util.get_tensor_n_elements(input_tensor)
    world_size = self._gloo_context.size
    pygloo.reduce_scatter(
        context,
        gloo_util.get_tensor_ptr(input_tensor),
        gloo_util.get_tensor_ptr(output_tensor),
        size,
        # Split the reduced result evenly across ranks.
        [size // world_size for _ in range(world_size)],
        gloo_util.get_gloo_tensor_dtype(output_tensor),
        gloo_util.get_gloo_reduce_op(reducescatter_options.reduceOp))
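For reference, the recvElems list built above splits the reduced result evenly across the participating ranks. A minimal standalone illustration of that split, using hypothetical example values (size = 12, world_size = 4, chosen so size is divisible by world_size as collective_fn assumes):

size, world_size = 12, 4  # hypothetical example values
recv_elems = [size // world_size for _ in range(world_size)]
assert recv_elems == [3, 3, 3, 3]  # each rank receives 3 elements of the reduced result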
import os
import shutil
import time

import numpy as np
import pygloo


def test_reduce_scatter(rank, world_size, fileStore_path):
    '''
    rank            # Rank of this process within the list of participating processes
    world_size      # Number of participating processes
    fileStore_path  # Shared directory used for the file-store rendezvous
    '''
    # Rank 0 (re)creates the rendezvous directory; the other ranks wait briefly
    # so the directory exists before they connect.
    if rank == 0:
        if os.path.exists(fileStore_path):
            shutil.rmtree(fileStore_path)
        os.makedirs(fileStore_path)
    else:
        time.sleep(0.5)

    context = pygloo.rendezvous.Context(rank, world_size)

    attr = pygloo.transport.tcp.attr("localhost")
    # Perform rendezvous for TCP pairs
    dev = pygloo.transport.tcp.CreateDevice(attr)
    fileStore = pygloo.rendezvous.FileStore(fileStore_path)
    store = pygloo.rendezvous.PrefixStore(str(world_size), fileStore)

    context.connectFullMesh(store, dev)

    # Every rank sends the same 1 + 2 + ... + world_size elements; rank r
    # receives r + 1 elements of the reduced result, as listed in recvElems.
    sendbuf = np.array(
        [i + 1 for i in range(sum([j + 1 for j in range(world_size)]))],
        dtype=np.float32)
    sendptr = sendbuf.ctypes.data

    recvbuf = np.zeros((rank + 1, ), dtype=np.float32)
    recvptr = recvbuf.ctypes.data
    recvElems = [i + 1 for i in range(world_size)]

    # Equivalent torch buffers (import torch to use these instead):
    # sendbuf = torch.Tensor([i + 1 for i in range(sum([j + 1 for j in range(world_size)]))]).float()
    # sendptr = sendbuf.data_ptr()
    # recvbuf = torch.zeros(rank + 1).float()
    # recvptr = recvbuf.data_ptr()

    data_size = sendbuf.size if isinstance(
        sendbuf, np.ndarray) else sendbuf.numpy().size
    datatype = pygloo.glooDataType_t.glooFloat32
    op = pygloo.ReduceOp.SUM

    pygloo.reduce_scatter(context, sendptr, recvptr, data_size, recvElems,
                          datatype, op)

    print(f"rank {rank} sends {sendbuf}, receives {recvbuf}")
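The test is meant to be run once per rank. A minimal driver sketch, assuming pygloo and NumPy are installed and using a hypothetical /tmp directory for the file-store rendezvous, spawns one process per rank with Python's multiprocessing:

import multiprocessing

if __name__ == "__main__":
    world_size = 2  # number of participating processes (example value)
    fileStore_path = "/tmp/pygloo_reduce_scatter"  # hypothetical rendezvous directory
    procs = [
        multiprocessing.Process(
            target=test_reduce_scatter,
            args=(rank, world_size, fileStore_path))
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

With world_size = 2, each rank sends [1, 2, 3]; after the SUM reduction rank 0 should receive [2.] and rank 1 should receive [4., 6.], matching recvElems = [1, 2].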