def all_reduce(tensor: torch.Tensor, op=ReduceOp.SUM, comm: MPI.COMM_WORLD = None) -> torch.Tensor:
    """Combine ``tensor`` across the ranks of ``comm`` with an MPI Allreduce.

    Args:
        tensor: Tensor to reduce. It is read through ``tensor.numpy()``, so
            it must be convertible to a NumPy array.
        op: Reduction to apply; its ``value`` attribute is forwarded to
            ``Allreduce`` as the ``op`` argument.
        comm: Communicator to reduce over. When ``None`` the communicator
            returned by ``_get_comm()`` is used.

    Returns:
        A new tensor built from the reduced buffer, with the same shape and
        dtype as the NumPy view of the input.
    """
    param_numpy = tensor.numpy()
    # A transposed or strided tensor yields a NumPy view that is not laid
    # out in C order.  Handing that to Allreduce directly would either be
    # rejected or be read in memory order rather than logical order, while
    # the receive buffer below is C-ordered.  Normalise the layout first.
    if not param_numpy.flags.c_contiguous:
        param_numpy = param_numpy.copy(order="C")
    param_output = np.empty(param_numpy.shape, dtype=param_numpy.dtype)
    if comm is None:
        comm = _get_comm()
    comm.Allreduce(param_numpy, param_output, op=op.value)
    return torch.from_numpy(param_output)
def all_reduce(tensor: EagerTensor, op=ReduceOp.SUM, comm: MPI.COMM_WORLD = None) -> EagerTensor:
    """Reduce ``tensor`` over all ranks of ``comm`` and return the result.

    The tensor is converted to NumPy, flattened to one dimension for the
    ``Allreduce`` call, and the reduced values are reshaped to the original
    shape before being passed to ``tutils.to_tensor``.

    Args:
        tensor: Tensor to reduce; read through ``tensor.numpy()``.
        op: Reduction to apply; ``op.value`` is forwarded to ``Allreduce``.
        comm: Communicator to use. Falls back to ``_get_comm()`` when
            ``None``.

    Returns:
        The value produced by ``tutils.to_tensor`` for the reduced array.
    """
    values = tensor.numpy()
    # flatten() returns a one-dimensional copy, used as the send buffer.
    send_flat = values.flatten()
    recv_flat = np.empty(send_flat.shape, dtype=values.dtype)
    communicator = _get_comm() if comm is None else comm
    communicator.Allreduce(send_flat, recv_flat, op=op.value)
    # Restore the caller's shape before converting back to a tensor.
    return tutils.to_tensor(recv_flat.reshape(values.shape))
def test_synchronize(self):
    """Check that ``Grid.synchronize`` repairs corrupted overlap regions.

    For every entry in ``self.cases`` the test confirms that a grid built
    without an overlap raises ``TypeError`` from ``synchronize``.  Then, for
    overlaps of 1 to 3, it uploads raw data whose first and last ``k``
    slices along the leading axis were deliberately offset by one, checks
    that the mismatch is visible, calls ``synchronize`` and checks that
    those slices now equal the reference from ``get_cpu_raw``.
    """
    for case in self.cases:
        space.initialize_space(case['shape'])
        # NOTE(review): the reduced array is not consumed by the no-overlap
        # check below; it is kept so the sequence of collective calls and
        # random draws is unchanged.
        data = np.random.randn(*case['shape']).astype(case['dtype'])
        cpu_data = np.empty_like(data)
        comm.Allreduce(data, cpu_data)

        g = Grid(case['dtype'])
        self.assertRaises(TypeError, g.synchronize)  # No overlap.

        # Test with-overlap cases as well.
        for k in range(1, 4):
            g = Grid(case['dtype'], x_overlap=k)

            # Overwrite entire grid.
            data = np.random.randn(*case['shape']).astype(case['dtype'])
            cpu_data = np.empty_like(data)
            comm.Allreduce(data, cpu_data)

            cpu_raw_bad = get_cpu_raw(cpu_data, k)
            cpu_raw_bad[:k, :, :] += 1  # Mess up padding areas.
            cpu_raw_bad[-k:, :, :] += 1
            drv.memcpy_htod(g.data.ptr, cpu_raw_bad)

            # Prove that the data is not synchronized at this time.
            cpu_raw = get_cpu_raw(cpu_data, k)
            gd = g._get_raw()
            self.assertTrue((gd[:k, :, :] != cpu_raw[:k, :, :]).all())
            self.assertTrue((gd[-k:, :, :] != cpu_raw[-k:, :, :]).all())

            g.synchronize()  # Synchronize the overlapping data.

            # Make sure that the overlap data is accurate.
            gd = g._get_raw()
            self.assertTrue((gd[:k, :, :] == cpu_raw[:k, :, :]).all())
            self.assertTrue((gd[-k:, :, :] == cpu_raw[-k:, :, :]).all())

            comm.Barrier()  # Wait for other mpi nodes to finish.
def test_recover(self):
    """Round-trip host data through a ``Grid`` and check it comes back intact.

    For each entry in ``self.cases`` an all-reduced random array is loaded
    into a grid, first without overlap and then with ``x_overlap`` of 1 and
    2.  Rank 0 compares the array returned by ``get()`` with the source
    data; with an overlap, the raw grid contents are also compared against
    ``get_cpu_raw``.
    """
    for spec in self.cases:
        space.initialize_space(spec['shape'])
        local_values = np.random.randn(*spec['shape']).astype(spec['dtype'])
        reduced = np.empty_like(local_values)
        comm.Allreduce(local_values, reduced)

        # Grid without any overlap.
        grid = Grid(reduced)
        recovered = grid.get()
        if comm.Get_rank() == 0:
            self.assertTrue((reduced == recovered).all())

        # Same round trip with overlapping grids.
        for overlap in (1, 2):
            grid = Grid(reduced, x_overlap=overlap)
            recovered = grid.get()
            if comm.Get_rank() == 0:
                self.assertTrue((reduced == recovered).all())

            expected_raw = get_cpu_raw(reduced, overlap)
            self.assertTrue((expected_raw == grid._get_raw()).all())