def test_horovod_alltoall_equal_split(self):
    """Test that the alltoall correctly distributes 1D tensors with default splitting."""
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # This test does not apply if NCCL version < 2.7.0
    if hvd.nccl_built() and hvd.nccl_built() < 2700:
        self.skipTest("NCCL-based Alltoall requires NCCL version >= 2.7.0.")

    dtypes = ['int32', 'int64', 'float32', 'float64']
    dims = [1, 2, 3]
    ctx = self._current_context()
    for dtype, dim in itertools.product(dtypes, dims):
        vals = []
        for i in range(size):
            vals += [i] * (rank + 1)
        tensor = mx.ndarray.array(vals, dtype=dtype, ctx=ctx)
        for _ in range(dim - 1):
            tensor = mx.ndarray.expand_dims(tensor, axis=1)
            tensor = mx.ndarray.concat(tensor, tensor, dim=1)

        collected = hvd.alltoall(tensor)

        assert collected.min() == rank, 'hvd.alltoall produces incorrect collected tensor'
        assert collected.max() == rank, 'hvd.alltoall produces incorrect collected tensor'
        assert collected.size == size * (size + 1) // 2 * 2**(dim - 1), \
            'hvd.alltoall collected wrong number of values'

def test_horovod_alltoall_type_error(self):
    """Test that the alltoall returns an error if the tensor types differ
       across the processes."""
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # This test does not apply if there is only one worker.
    if size == 1:
        self.skipTest("Only one worker available")

    # This test does not apply if NCCL version < 2.7.0
    if hvd.nccl_built() and hvd.nccl_built() < 2700:
        self.skipTest("NCCL-based Alltoall requires NCCL version >= 2.7.0.")

    ctx = self._current_context()
    if rank % 2:
        tensor = mx.ndarray.empty([size], dtype='int32', ctx=ctx)
    else:
        tensor = mx.ndarray.empty([size], dtype='float32', ctx=ctx)

    try:
        output = hvd.alltoall(tensor)
        output.wait_to_read()
        assert False, 'hvd.alltoall did not throw error'
    except (MXNetError, RuntimeError):
        pass

def test_horovod_alltoall_rank_error(self):
    """Test that the alltoall returns an error if any dimension besides
       the first is different among the tensors being processed."""
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # This test does not apply if there is only one worker.
    if size == 1:
        self.skipTest("Only one worker available")

    # This test does not apply if NCCL version < 2.7.0
    if hvd.nccl_built() and hvd.nccl_built() < 2700:
        self.skipTest("NCCL-based Alltoall requires NCCL version >= 2.7.0.")

    ctx = self._current_context()
    tensor_size = [2 * size] * 3
    tensor_size[1] = 10 * (rank + 1)
    tensor = mx.ndarray.ones(shape=tensor_size, ctx=ctx)

    try:
        output = hvd.alltoall(tensor)
        output.wait_to_read()
        assert False, 'hvd.alltoall did not throw error'
    except (MXNetError, RuntimeError):
        pass

def test_horovod_alltoall(self):
    """Test that the alltoall correctly distributes 1D, 2D, and 3D tensors."""
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # This test does not apply if NCCL version < 2.7.0
    if hvd.nccl_built() and hvd.nccl_built() < 2700:
        self.skipTest("NCCL-based Alltoall requires NCCL version >= 2.7.0.")

    dtypes = ['int32', 'int64', 'float32', 'float64']
    dims = [1, 2, 3]
    ctx = self._current_context()
    for dtype, dim in itertools.product(dtypes, dims):
        vals = []
        for i in range(size):
            vals += [i] * (rank + 1)
        tensor = mx.ndarray.array(vals, dtype=dtype, ctx=ctx)
        for _ in range(dim - 1):
            tensor = mx.ndarray.expand_dims(tensor, axis=1)
            tensor = mx.ndarray.concat(tensor, tensor, dim=1)

        splits = mx.ndarray.array([rank + 1] * size, dtype='int32', ctx=ctx)
        collected, received_splits = hvd.alltoall(tensor, splits)

        assert collected.min() == rank, 'hvd.alltoall produces incorrect collected tensor'
        assert collected.max() == rank, 'hvd.alltoall produces incorrect collected tensor'
        assert collected.size == size * (size + 1) // 2 * 2**(dim - 1), \
            'hvd.alltoall collected wrong number of values'
        self.assertSequenceEqual(received_splits.asnumpy().tolist(),
                                 [rk + 1 for rk in range(size)],
                                 "hvd.alltoall returned incorrect received_splits")

def test_horovod_alltoall_splits_type_error(self):
    """Test that the alltoall returns an error if the splits tensor does not
       contain 32-bit integers."""
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # This test does not apply if using gloo controller
    if hvd.gloo_enabled():
        self.skipTest("Alltoall currently does not support Gloo controller.")

    # This test does not apply if NCCL version < 2.7.0
    if hvd.nccl_built() and hvd.nccl_built() < 2700:
        self.skipTest("NCCL-based Alltoall requires NCCL version >= 2.7.0.")

    ctx = self._current_context()
    tensor = mx.ndarray.empty([size], ctx=ctx)
    splits = mx.ndarray.ones([size], dtype='float32', ctx=ctx)

    try:
        hvd.alltoall(tensor, splits)
        assert False, 'hvd.alltoall did not throw error'
    except (MXNetError, ValueError):
        pass

def test_horovod_alltoall_equal_split_length_error(self):
    """Test that the alltoall with default splitting returns an error if the
       first dimension of the tensor is not a multiple of the number of workers."""
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # This test does not apply if there is only one worker.
    if size == 1:
        self.skipTest("Only one worker available")

    # This test does not apply if using gloo controller
    if hvd.gloo_enabled():
        self.skipTest("Alltoall currently does not support Gloo controller.")

    # This test does not apply if NCCL version < 2.7.0
    if hvd.nccl_built() and hvd.nccl_built() < 2700:
        self.skipTest("NCCL-based Alltoall requires NCCL version >= 2.7.0.")

    ctx = self._current_context()
    tensor = mx.ndarray.empty([size + 1], ctx=ctx)

    try:
        hvd.alltoall(tensor)
        assert False, 'hvd.alltoall did not throw error'
    except (MXNetError, RuntimeError):
        pass

def test_horovod_alltoall_splits_error(self):
    """Test that the alltoall returns an error if the sum of the splits entries
       exceeds the first dimension of the input tensor."""
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # This test does not apply if NCCL version < 2.7.0
    if hvd.nccl_built() and hvd.nccl_built() < 2700:
        self.skipTest("NCCL-based Alltoall requires NCCL version >= 2.7.0.")

    ctx = self._current_context()
    tensor = mx.ndarray.empty([size - 1], ctx=ctx)
    splits = mx.ndarray.ones([size], dtype='int32', ctx=ctx)

    try:
        hvd.alltoall(tensor, splits)
        assert False, 'hvd.alltoall did not throw error'
    except (MXNetError, RuntimeError):
        pass