def multi_node_mean(self, array_a, array_b):
    # Despite resembling an allreduce, this actually computes a mean:
    #   Sigma(a, all-procs)/n -> b, or
    #   Sigma(b, all-procs)/n -> b (in place) if array_a is None
    if chainer.is_debug():
        self.check_ready_to_allreduce(array_a, array_b)

    is_float16 = array_b.dtype == numpy.float16

    if array_a is None:
        buffer_a = mpi4py.MPI.IN_PLACE
    elif is_float16:
        assert array_a.dtype == array_b.dtype
        buffer_a = _memory_utility.array_to_buffer_object(
            array_a.astype(numpy.float32))
    else:
        buffer_a = _memory_utility.array_to_buffer_object(array_a)

    # float16 is not supported by MPI reductions, so communicate in float32
    # and cast the result back afterwards.
    if is_float16:
        array_b32 = array_b.astype(numpy.float32)
    else:
        array_b32 = array_b
    buffer_b = _memory_utility.array_to_buffer_object(array_b32)

    self.mpi_comm.Allreduce(buffer_a, buffer_b)

    if is_float16:
        xp = chainer.backend.get_array_module(array_b)
        xp.copyto(array_b, array_b32.astype(numpy.float16), casting='no')

    array_b *= 1.0 / self.mpi_comm.size

    if chainer.is_debug():
        self.ensure_all_finite(array_b)
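# Usage sketch (not from the source): `comm` is assumed to be an instance of
# the communicator above, run under mpiexec; `comm.rank` and `comm.size` are
# assumed attributes. Averages an array elementwise across all ranks.
import numpy

def example_multi_node_mean(comm):
    b = numpy.full((3,), comm.rank, dtype=numpy.float32)
    comm.multi_node_mean(None, b)  # array_a=None -> reduce b in place
    # With n ranks holding 0, 1, ..., n-1, the mean is (n - 1) / 2.
    assert numpy.allclose(b, (comm.size - 1) / 2.0)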
def allreduce(self, x):
    """A primitive of inter-process allreduce communication.

    This method tries to invoke allreduce communication within the
    communicator. All processes in the communicator are expected to
    invoke ``allreduce()``. This method relies on mpi4py fast
    communication optimized for numpy arrays, as well as ``send()``
    and ``recv()``.

    Note that this method can only handle the same shapes of data
    over all processes, and cannot handle tuple data.

    If ``x`` is a numpy array, the received data will also be allocated
    as a numpy array. Additionally, when ``x`` is a cupy array, the
    returned array will be placed at the current device
    (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
    regardless of which device the argument is placed at remote nodes.

    Args:
        x (numpy/cupy array): An array to apply allreduce operation.

    Returns:
        ys (numpy/cupy array): An array to which allreduce (currently
            SUM only) has been applied.
    """
    chainer.utils.experimental(
        'chainermn.communicators.CommunicatorBase.allreduce')

    msgtype = _MessageType(x)
    _check_dtype('allreduce', msgtype)

    if msgtype.is_tuple:
        raise TypeError('allreduce cannot handle tuple data')

    xp = chainer.cuda.get_array_module(x)

    # TODO(kuenishi): do we check all messages have same shape and dims?

    # Source buffer
    sbuf = _memory_utility.array_to_buffer_object(
        x, _get_mpi_type(msgtype))
    # Destination array; keep the array reference so the reduced result
    # can be reshaped after the collective call.
    dbuf = xp.empty(
        [numpy.prod(msgtype.shapes[0])], dtype=msgtype.dtype)
    self.mpi_comm.Allreduce(
        sbuf,
        _memory_utility.array_to_buffer_object(dbuf, _get_mpi_type(msgtype)))
    return dbuf.reshape(msgtype.shapes[0])
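# Usage sketch (illustrative; `comm` is a hypothetical communicator created
# elsewhere, e.g. with chainermn.create_communicator, and every rank must
# make the same call with a same-shaped array):
import numpy

def example_allreduce(comm):
    x = numpy.ones((2, 3), dtype=numpy.float32)
    y = comm.allreduce(x)  # elementwise SUM over all ranks
    assert y.shape == (2, 3)
    assert numpy.allclose(y, comm.size)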
def recv(self, source, tag):
    """A primitive of inter-process receiver.

    This method tries to receive a numpy array from the target process.
    The target process is expected to invoke ``send()``. This method
    relies on mpi4py fast communication optimized for numpy arrays,
    which discards any information attached to chainer.Variable
    objects. Please be aware of this.

    If the corresponding ``send()`` is invoked with a cupy array,
    the returned array will be placed at the current device
    (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
    regardless of which device the argument is placed at remote nodes.

    Args:
        source (int): Target process specifier.
        tag (int): Message ID (MPI feature).

    Returns:
        data (tuple of numpy/cupy array or numpy/cupy array):
            Received data. If ``send()`` is invoked with tuple data,
            it is also a tuple. Otherwise, it is a vanilla numpy/cupy
            array.
    """
    chainer.utils.experimental(
        'chainermn.communicators.MpiCommunicatorBase.recv')

    msgtype = self.mpi_comm.recv(source=source, tag=tag)
    xp = msgtype.get_array_module()

    if msgtype.is_tuple:
        msg = []
        for shape in msgtype.shapes:
            buf = xp.empty(
                [chainer.utils.size_of_shape(shape)], dtype=msgtype.dtype)
            rtype = _get_mpi_type(msgtype)
            self.mpi_comm.Recv(
                _memory_utility.array_to_buffer_object(buf, rtype),
                source=source, tag=tag)
            msg.append(buf.reshape(shape))
        return tuple(msg)

    else:
        assert len(msgtype.shapes) == 1
        shape = msgtype.shapes[0]
        buf = xp.empty(
            [chainer.utils.size_of_shape(shape)], dtype=msgtype.dtype)
        rtype = _get_mpi_type(msgtype)
        self.mpi_comm.Recv(
            _memory_utility.array_to_buffer_object(buf, rtype),
            source=source, tag=tag)
        return buf.reshape(shape)
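# Usage sketch pairing recv() with send() (hypothetical two-rank setup;
# the send(data, dest, tag) signature is an assumption based on the
# docstring above and is not shown in this section):
import numpy

def example_send_recv(comm):
    if comm.rank == 0:
        comm.send(numpy.arange(4, dtype=numpy.float32), dest=1, tag=0)
    elif comm.rank == 1:
        data = comm.recv(source=0, tag=0)
        assert data.shape == (4,)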
def bcast(self, x, root=0):
    """A primitive of inter-process broadcast communication.

    This method tries to invoke broadcast communication within the
    communicator. All processes in the communicator are expected to
    invoke ``bcast()``. This method relies on mpi4py fast communication
    optimized for numpy arrays, as well as ``send()`` and ``recv()``.

    If ``bcast()`` is invoked with a cupy array in the root process,
    the returned array will be placed at the current device
    (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
    regardless of which device the argument is placed at remote nodes.

    Args:
        x (numpy/cupy array): Array to be broadcasted.
        root (int): Rank of root process.

    Returns:
        ys (numpy/cupy array): Received array.
    """
    chainer.utils.experimental(
        'chainermn.communicators.MpiCommunicatorBase.bcast')

    is_master = self.mpi_comm.rank == root

    if is_master:
        msgtype = _MessageType(x)
        _check_dtype('bcast', msgtype)

        if msgtype.is_tuple:
            raise TypeError('Tuple data cannot be broadcasted')

        msgtype = self.mpi_comm.bcast(msgtype, root)
        buf = _memory_utility.array_to_buffer_object(
            x, _get_mpi_type(msgtype))
        self.mpi_comm.Bcast(buf, root)
        return x
    else:
        msgtype = self.mpi_comm.bcast(None, root)
        xp = msgtype.get_array_module()
        shape = msgtype.shapes[0]
        buf = xp.empty(
            [chainer.utils.size_of_shape(shape)], dtype=msgtype.dtype)
        buftype = _get_mpi_type(msgtype)
        self.mpi_comm.Bcast(
            _memory_utility.array_to_buffer_object(buf, buftype), root)
        return buf.reshape(shape)
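# Usage sketch (hypothetical `comm`): only the root's argument matters;
# non-root ranks may pass any placeholder and still receive the data.
import numpy

def example_bcast(comm):
    if comm.rank == 0:
        x = numpy.arange(6, dtype=numpy.float32).reshape(2, 3)
    else:
        x = None  # ignored on non-root ranks
    y = comm.bcast(x, root=0)
    assert y.shape == (2, 3)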
def allgather(self, x):
    chainer.utils.experimental(
        'chainermn.communicators.MpiCommunicatorBase.allgather')

    msgtype = _MessageType(x)
    _check_dtype('allgather', msgtype)

    msgtypes = self.mpi_comm.allgather(msgtype)
    _check_dtypes_are_same(msgtypes)

    # Type check.
    for msgtype in msgtypes:
        if msgtype.is_tuple:
            raise TypeError('allgather cannot handle tuple data')

        assert len(msgtype.shapes) == 1

    # Collective communication.
    xp = chainer.backend.get_array_module(x)
    shapes = [msgtype.shapes[0] for msgtype in msgtypes]
    sbuf = _memory_utility.array_to_buffer_object(
        x, _get_mpi_type(msgtype))
    rlens = [chainer.utils.size_of_shape(s) for s in shapes]
    rbuf = xp.empty([sum(rlens)], dtype=msgtype.dtype)
    if xp is not numpy:
        chainer.cuda.Stream.null.synchronize()
    self.mpi_comm.Allgatherv(
        sbuf,
        [_memory_utility.get_device_memory_pointer(rbuf),
         (rlens, _cnt_to_dsp(rlens)),
         _get_mpi_type(msgtype)])
    ys = [rbuf[i:i + l].reshape(s)
          for i, l, s in zip(_cnt_to_dsp(rlens), rlens, shapes)]
    return tuple(ys)
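# Usage sketch (hypothetical `comm`): per-rank shapes may differ (the code
# above exchanges them and uses Allgatherv), but dtypes must match.
import numpy

def example_allgather(comm):
    x = numpy.full((comm.rank + 1,), comm.rank, dtype=numpy.float32)
    ys = comm.allgather(x)
    assert len(ys) == comm.size
    assert ys[comm.rank].shape == (comm.rank + 1,)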
def allreduce_grad(self, model):
    for param in _memory_utility.extract_params_set_grad(model):
        grad = param.grad
        is_float16 = param.grad.dtype == np.float16
        if is_float16:
            grad = grad.astype(np.float32)
        buf = _memory_utility.array_to_buffer_object(grad)
        self.mpi_comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)
        if is_float16:
            param.grad = grad.astype(np.float16)
        param.grad /= self.size
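# Usage sketch of a data-parallel training step (hypothetical model,
# optimizer, and batches; in ChainerMN this is normally driven through a
# multi-node optimizer rather than called by hand):
def example_training_step(comm, model, optimizer, x_batch, y_batch):
    loss = model(x_batch, y_batch)
    model.cleargrads()
    loss.backward()
    comm.allreduce_grad(model)  # gradients become the mean over all ranks
    optimizer.update()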
def bcast_data(self, model):
    for _, param in sorted(model.namedparams()):
        if param.data is not None:
            data = param.data
            is_float16 = param.data.dtype == numpy.float16
            if is_float16:
                data = data.astype(numpy.float32)
            buf = _memory_utility.array_to_buffer_object(data)
            self.mpi_comm.Bcast(buf)
            if is_float16:
                param.data = data.astype(numpy.float16)
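# Usage sketch (hypothetical `comm` and `model`): broadcast the root's
# parameters once so all ranks start training from identical weights.
def example_sync_initial_params(comm, model):
    comm.bcast_data(model)  # every param.data now matches the root's values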
def __call__(self, trainer=None):
    # The MPI4py import must be delayed; note that the _memory_utility
    # module imports MPI4py as well.
    from chainermn.communicators._memory_utility \
        import array_to_buffer_object
    import mpi4py.MPI

    for _, param in sorted(_namedpersistents(self.model)):
        if hasattr(param, 'dtype') and param.dtype == np.float32:
            buf = array_to_buffer_object(param)
            self.comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)
            param /= self.comm.size
        else:
            pass  # Integer persistent variables are ignored
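# Usage sketch (assumes this method belongs to something like ChainerMN's
# AllreducePersistent extension; the name, constructor signature, and
# trigger interval below are assumptions): averaging float32 persistent
# values such as BatchNormalization running statistics across ranks.
def example_register_allreduce_persistent(trainer, model, comm):
    from chainermn.extensions import AllreducePersistent
    trainer.extend(AllreducePersistent(model, comm), trigger=(1, 'epoch'))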
def bcast_data(self, model):
    for _, param in sorted(model.namedparams()):
        buf = _memory_utility.array_to_buffer_object(param.data)
        self.mpi_comm.Bcast(buf)
def gather(self, x, root=0):
    """A primitive of inter-process gather communication.

    This method tries to invoke gather communication within the
    communicator. All processes in the communicator are expected to
    invoke ``gather()``. This method relies on mpi4py fast communication
    optimized for numpy arrays, as well as ``send()`` and ``recv()``.

    If ``x`` is a numpy array, the received data will also be allocated
    as a numpy array. Additionally, when ``x`` is a cupy array, the
    returned array will be placed at the current device
    (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
    regardless of which device the argument is placed at remote nodes.

    Args:
        x (numpy/cupy array): Array to be gathered.
        root (int): Rank of root process.

    Returns:
        ys (tuple of numpy/cupy array):
            Received arrays. ``None`` for non-root processes.
    """
    chainer.utils.experimental(
        'chainermn.communicators.MpiCommunicatorBase.gather')

    is_master = self.mpi_comm.rank == root

    msgtype = _MessageType(x)
    _check_dtype('gather', msgtype)

    msgtypes = self.mpi_comm.gather(msgtype, root)

    if is_master:
        _check_dtypes_are_same(msgtypes)

        for msgtype in msgtypes:
            if msgtype.is_tuple:
                raise TypeError('gather cannot handle tuple data')

            assert len(msgtype.shapes) == 1

        xp = chainer.backend.get_array_module(x)
        sbuf = _memory_utility.array_to_buffer_object(
            x, _get_mpi_type(msgtype))
        shapes = [mty.shapes[0] for mty in msgtypes]
        rlens = [chainer.utils.size_of_shape(s) for s in shapes]
        rbuf = xp.empty([sum(rlens)], dtype=msgtype.dtype)

        if xp is not numpy:
            chainer.cuda.Stream.null.synchronize()

        self.mpi_comm.Gatherv(
            sbuf,
            [_memory_utility.get_device_memory_pointer(rbuf),
             (rlens, _cnt_to_dsp(rlens)),
             _get_mpi_type(msgtype)],
            root)

        ys = [rbuf[i:i + l].reshape(s)
              for i, l, s in zip(_cnt_to_dsp(rlens), rlens, shapes)]
        return tuple(ys)

    else:
        sbuf = _memory_utility.array_to_buffer_object(
            x, _get_mpi_type(msgtype))
        self.mpi_comm.Gatherv(sbuf, None, root)
        return None
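# Usage sketch (hypothetical `comm`): every rank contributes one array;
# only the root receives the tuple, all other ranks get None.
import numpy

def example_gather(comm):
    x = numpy.full((2,), comm.rank, dtype=numpy.float32)
    ys = comm.gather(x, root=0)
    if comm.rank == 0:
        assert len(ys) == comm.size
    else:
        assert ys is None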
def broadcast_naive(mpi_comm, model):
    for _, param in sorted(model.namedparams()):
        buf = _memory_utility.array_to_buffer_object(param.data)
        mpi_comm.Bcast(buf)
def allreduce_grad(self, model):
    for _, param in sorted(model.namedparams()):
        buf = _memory_utility.array_to_buffer_object(param.grad)
        self.mpi_comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)
        param.grad /= self.size
def allreduce_grad(self, model):
    for param in _memory_utility.extract_params(model):
        buf = _memory_utility.array_to_buffer_object(param.grad)
        self.mpi_comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)
        param.grad /= self.size
def allreduce_grad(self, model):
    for param in _memory_utility.extract_params_set_grad(model):
        buf = _memory_utility.array_to_buffer_object(param.grad)
        self.mpi_comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)
        param.grad /= self.size
def gather(self, x, root=0):
    """A primitive of inter-process gather communication.

    This method tries to invoke gather communication within the
    communicator. All processes in the communicator are expected to
    invoke ``gather()``. This method relies on mpi4py fast communication
    optimized for numpy arrays, as well as ``send()`` and ``recv()``.

    If ``x`` is a numpy array, the received data will also be allocated
    as a numpy array. Additionally, when ``x`` is a cupy array, the
    returned array will be placed at the current device
    (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
    regardless of which device the argument is placed at remote nodes.

    Args:
        x (numpy/cupy array): Array to be gathered.
        root (int): Rank of root process.

    Returns:
        ys (tuple of numpy/cupy array):
            Received arrays. ``None`` for non-root processes.
    """
    chainer.utils.experimental(
        'chainermn.communicators.MpiCommunicatorBase.gather')

    is_master = self.mpi_comm.rank == root

    msgtype = _MessageType(x)
    _check_dtype('gather', msgtype)

    msgtypes = self.mpi_comm.gather(msgtype, root)

    if is_master:
        _check_dtypes_are_same(msgtypes)

        for msgtype in msgtypes:
            if msgtype.is_tuple:
                raise TypeError('gather cannot handle tuple data')

            assert len(msgtype.shapes) == 1

        xp = chainer.backend.get_array_module(x)
        sbuf = _memory_utility.array_to_buffer_object(
            x, _get_mpi_type(msgtype))
        shapes = [mty.shapes[0] for mty in msgtypes]
        rlens = [numpy.prod(s, dtype=int) for s in shapes]
        rbuf = xp.empty([sum(rlens)], dtype=msgtype.dtype)

        if xp is not numpy:
            chainer.cuda.Stream.null.synchronize()

        self.mpi_comm.Gatherv(sbuf, [
            _memory_utility.get_device_memory_pointer(rbuf),
            (rlens, _cnt_to_dsp(rlens)),
            _get_mpi_type(msgtype)
        ], root)

        ys = [
            rbuf[i:i + l].reshape(s)
            for i, l, s in zip(_cnt_to_dsp(rlens), rlens, shapes)
        ]
        return tuple(ys)

    else:
        sbuf = _memory_utility.array_to_buffer_object(
            x, _get_mpi_type(msgtype))
        self.mpi_comm.Gatherv(sbuf, None, root)
        return None
def scatter(self, xs, root=0):
    """A primitive of inter-process scatter communication.

    This method tries to invoke scatter communication within the
    communicator. All processes in the communicator are expected to
    invoke ``scatter()``. This method relies on mpi4py fast
    communication optimized for numpy arrays, as well as ``send()``
    and ``recv()``.

    If ``xs`` is a tuple, each element is sent to a different process.
    The length of the tuple must be the same as the communicator size.
    If ``xs`` is a ``numpy.ndarray``, it is split along the first axis
    and sent to different processes. For slave processes, ``xs`` is
    allowed to be any value (it will be ignored).

    If ``scatter()`` is invoked with a cupy array in the root process,
    the returned array will be placed at the current device
    (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
    regardless of which device the argument is placed at remote nodes.

    Args:
        xs (tuple of numpy/cupy array): Arrays to be scattered.
        root (int): Rank of root process.

    Returns:
        ys (numpy/cupy array): Received array.
    """
    chainer.utils.experimental(
        'chainermn.communicators.CommunicatorBase.scatter')

    is_master = self.mpi_comm.rank == root

    if is_master:
        # Type check.
        msgtype = _MessageType(xs)
        _check_dtype('scatter', msgtype)

        if msgtype.is_tuple:
            if len(msgtype.shapes) != self.size:
                raise ValueError(
                    'the length of xs must be consistent '
                    'with communicator size')

            xp = chainer.backend.get_array_module(*xs)
            msgtype = tuple([_MessageType(x) for x in xs])
            shapes = [mty.shapes[0] for mty in msgtype]
            # concatenate([x.reshape(-1) ... ], axis=0) will fail
            xs = xp.concatenate([x.reshape(1, -1) for x in xs], axis=1)

        else:
            assert len(msgtype.shapes) == 1

            if msgtype.shapes[0][0] != self.mpi_comm.size:
                raise ValueError(
                    'scatter received inconsistent number of inputs '
                    'with communicator size')

            xp = chainer.backend.get_array_module(xs)
            msgtype = tuple([_MessageType(xs[0]) for _ in range(self.size)])
            shapes = [xs.shape[1:] for _ in range(self.size)]

        msgtype = self.mpi_comm.scatter(msgtype, root)
        shape = msgtype.shapes[0]

        # Collective communication.
        slens = [chainer.utils.size_of_shape(s) for s in shapes]
        sbuf = _memory_utility.get_device_memory_pointer(xs)
        rbuf = xp.empty(
            [chainer.utils.size_of_shape(shape)], dtype=msgtype.dtype)
        rtype = _get_mpi_type(msgtype)
        if xp is not numpy:
            chainer.cuda.Stream.null.synchronize()

        self.mpi_comm.Scatterv(
            [sbuf, (slens, _cnt_to_dsp(slens)), _get_mpi_type(msgtype)],
            _memory_utility.array_to_buffer_object(rbuf, rtype), root)

        return rbuf.reshape(shape)

    else:  # slave processes
        msgtypes = self.mpi_comm.scatter(None, root)
        xp = msgtypes.get_array_module()
        shape = msgtypes.shapes[0]
        rbuf = xp.empty(
            [chainer.utils.size_of_shape(shape)], dtype=msgtypes.dtype)
        rtype = _get_mpi_type(msgtypes)
        self.mpi_comm.Scatterv(
            None,
            _memory_utility.array_to_buffer_object(rbuf, rtype),
            root)
        return rbuf.reshape(shape)
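# Usage sketch (hypothetical `comm`): the root splits an ndarray along the
# first axis; each rank, including the root, receives one slice.
import numpy

def example_scatter(comm):
    if comm.rank == 0:
        xs = numpy.arange(comm.size * 3,
                          dtype=numpy.float32).reshape(comm.size, 3)
    else:
        xs = None  # ignored on non-root ranks
    y = comm.scatter(xs, root=0)
    assert y.shape == (3,)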
def recv(self, source, tag):
    """A primitive of inter-process receiver.

    This method tries to receive a numpy array from the target process.
    The target process is expected to invoke ``send()``. This method
    relies on mpi4py fast communication optimized for numpy arrays,
    which discards any information attached to chainer.Variable
    objects. Please be aware of this.

    If the corresponding ``send()`` is invoked with a cupy array,
    the returned array will be placed at the current device
    (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
    regardless of which device the argument is placed at remote nodes.

    Args:
        source (int): Target process specifier.
        tag (int): Message ID (MPI feature).

    Returns:
        data (tuple of numpy/cupy array or numpy/cupy array):
            Received data. If ``send()`` is invoked with tuple data,
            it is also a tuple. Otherwise, it is a vanilla numpy/cupy
            array.
    """
    chainer.utils.experimental(
        'chainermn.communicators.MpiCommunicatorBase.recv')

    msgtype = self.mpi_comm.recv(source=source, tag=tag)
    xp = msgtype.get_array_module()

    # float16 data is communicated as float32 and cast back on arrival.
    if numpy.float16 == msgtype.dtype:
        comm_dtype = numpy.float32
    else:
        comm_dtype = msgtype.dtype

    if msgtype.is_tuple:
        msg = []
        for shape in msgtype.shapes:
            buf = xp.empty(
                [chainer.utils.size_of_shape(shape)], dtype=comm_dtype)
            rtype = _get_mpi_type(msgtype)
            self.mpi_comm.Recv(
                _memory_utility.array_to_buffer_object(buf, rtype),
                source=source, tag=tag)
            if numpy.float16 == msgtype.dtype:
                buf = buf.astype(numpy.float16)
            msg.append(buf.reshape(shape))
        return tuple(msg)

    else:
        assert len(msgtype.shapes) == 1
        shape = msgtype.shapes[0]
        buf = xp.empty(
            [chainer.utils.size_of_shape(shape)], dtype=comm_dtype)
        rtype = _get_mpi_type(msgtype)
        self.mpi_comm.Recv(
            _memory_utility.array_to_buffer_object(buf, rtype),
            source=source, tag=tag)
        if numpy.float16 == msgtype.dtype:
            buf = buf.astype(numpy.float16)
        return buf.reshape(shape)
def bcast_data(self, model):
    for _, param in sorted(model.namedparams()):
        if param.data is not None:
            buf = _memory_utility.array_to_buffer_object(param.data)
            self.mpi_comm.Bcast(buf)
def alltoall(self, xs):
    """A primitive of inter-process all-to-all function.

    This method tries to invoke all-to-all communication within the
    communicator. All processes in the communicator are expected to
    invoke ``alltoall()``. This method relies on mpi4py fast
    communication optimized for numpy arrays, as well as ``send()``
    and ``recv()``.

    Args:
        xs (tuple of numpy.ndarray): Arrays to send, one per process.

    Returns:
        ys (tuple of numpy.ndarray):
            Received arrays. The length of the tuple equals the
            communicator size.
    """
    chainer.utils.experimental(
        'chainermn.communicators.MpiCommunicatorBase.alltoall')

    if len(xs) != self.size:
        raise ValueError(
            'The length of data must be same as communicator size.')

    # Type check.
    for x in xs:
        if x.dtype != numpy.float32:
            raise ValueError(
                'alltoall only supports dtype == numpy.float32')

    # Exchange the number of axes of each array.
    sndims = numpy.array([x.ndim for x in xs], dtype=numpy.int32)
    rndims = numpy.empty(self.size, dtype=numpy.int32)
    self.mpi_comm.Alltoall(
        [sndims, mpi4py.MPI.INT],
        [rndims, mpi4py.MPI.INT])

    # Exchange the shapes of the arrays.
    sshapes = numpy.hstack([x.shape for x in xs]).astype(numpy.int32)
    rshapes = numpy.empty(sum(rndims), dtype=numpy.int32)
    self.mpi_comm.Alltoallv(
        [sshapes, (sndims, _cnt_to_dsp(sndims)), mpi4py.MPI.INT],
        [rshapes, (rndims, _cnt_to_dsp(rndims)), mpi4py.MPI.INT])
    shapes = [rshapes[i:i + l]
              for i, l in zip(_cnt_to_dsp(rndims), rndims)]

    # Collective communication.
    slens = [numpy.prod(x.shape) for x in xs]
    xp = chainer.cuda.get_array_module(xs[0])
    sbuf = xp.hstack([x.reshape(-1) for x in xs])
    rlens = [numpy.prod(s) for s in shapes]
    rbuf = numpy.empty(sum(rlens), dtype=numpy.float32)
    if xp is not numpy:
        sbuf = _memory_utility.array_to_buffer_object(sbuf)[0]
        chainer.cuda.Stream.null.synchronize()
    self.mpi_comm.Alltoallv(
        [sbuf, (slens, _cnt_to_dsp(slens)), mpi4py.MPI.FLOAT],
        [rbuf, (rlens, _cnt_to_dsp(rlens)), mpi4py.MPI.FLOAT])
    ys = [rbuf[i:i + l].reshape(s)
          for i, l, s in zip(_cnt_to_dsp(rlens), rlens, shapes)]
    return tuple(ys)
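# Usage sketch (hypothetical `comm`): rank r sends xs[k] to rank k and
# receives one array from every rank, so len(ys) == comm.size.
import numpy

def example_alltoall(comm):
    xs = tuple(numpy.full((2,), float(k), dtype=numpy.float32)
               for k in range(comm.size))
    ys = comm.alltoall(xs)
    assert len(ys) == comm.size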
def scatter(self, xs, root=0):
    """A primitive of inter-process scatter communication.

    This method tries to invoke scatter communication within the
    communicator. All processes in the communicator are expected to
    invoke ``scatter()``. This method relies on mpi4py fast
    communication optimized for numpy arrays, as well as ``send()``
    and ``recv()``.

    If ``xs`` is a tuple, each element is sent to a different process.
    The length of the tuple must be the same as the communicator size.
    If ``xs`` is a ``numpy.ndarray``, it is split along the first axis
    and sent to different processes. For slave processes, ``xs`` is
    allowed to be any value (it will be ignored).

    If ``scatter()`` is invoked with a cupy array in the root process,
    the returned array will be placed at the current device
    (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
    regardless of which device the argument is placed at remote nodes.

    Args:
        xs (tuple of numpy/cupy array): Arrays to be scattered.
        root (int): Rank of root process.

    Returns:
        ys (numpy/cupy array): Received array.
    """
    chainer.utils.experimental(
        'chainermn.communicators.CommunicatorBase.scatter')

    is_master = self.mpi_comm.rank == root

    if is_master:
        # Type check.
        msgtype = _MessageType(xs)
        _check_dtype('scatter', msgtype)

        if msgtype.is_tuple:
            if len(msgtype.shapes) != self.size:
                raise ValueError('the length of xs must be consistent '
                                 'with communicator size')

            xp = chainer.backend.get_array_module(*xs)
            msgtype = tuple([_MessageType(x) for x in xs])
            shapes = [mty.shapes[0] for mty in msgtype]
            # concatenate([x.reshape(-1) ... ], axis=0) will fail
            xs = xp.concatenate([x.reshape(1, -1) for x in xs], axis=1)

        else:
            assert len(msgtype.shapes) == 1

            if msgtype.shapes[0][0] != self.mpi_comm.size:
                raise ValueError(
                    'scatter received inconsistent number of inputs '
                    'with communicator size')

            xp = chainer.backend.get_array_module(xs)
            msgtype = tuple(
                [_MessageType(xs[0]) for _ in range(self.size)])
            shapes = [xs.shape[1:] for _ in range(self.size)]

        msgtype = self.mpi_comm.scatter(msgtype, root)
        shape = msgtype.shapes[0]

        # Collective communication.
        slens = [numpy.prod(s, dtype=int) for s in shapes]
        sbuf = _memory_utility.get_device_memory_pointer(xs)
        rbuf = xp.empty([numpy.prod(shape, dtype=int)],
                        dtype=msgtype.dtype)
        rtype = _get_mpi_type(msgtype)
        if xp is not numpy:
            chainer.cuda.Stream.null.synchronize()

        self.mpi_comm.Scatterv(
            [sbuf, (slens, _cnt_to_dsp(slens)), _get_mpi_type(msgtype)],
            _memory_utility.array_to_buffer_object(rbuf, rtype), root)

        return rbuf.reshape(shape)

    else:  # slave processes
        msgtypes = self.mpi_comm.scatter(None, root)
        xp = msgtypes.get_array_module()
        shape = msgtypes.shapes[0]
        rbuf = xp.empty([numpy.prod(shape, dtype=int)],
                        dtype=msgtypes.dtype)
        rtype = _get_mpi_type(msgtypes)
        self.mpi_comm.Scatterv(
            None,
            _memory_utility.array_to_buffer_object(rbuf, rtype),
            root)
        return rbuf.reshape(shape)
def scatter(self, xs, root=0):
    """A primitive of inter-process scatter communication.

    This method tries to invoke scatter communication within the
    communicator. All processes in the communicator are expected to
    invoke ``scatter()``. This method relies on mpi4py fast
    communication optimized for numpy arrays, as well as ``send()``
    and ``recv()``.

    If ``xs`` is a tuple, each element is sent to a different process.
    The length of the tuple must be the same as the communicator size.
    If ``xs`` is a ``numpy.ndarray``, it is split along the first axis
    and sent to different processes. For slave processes, ``xs`` is
    allowed to be any value (it will be ignored).

    Args:
        xs (tuple of numpy.ndarray or numpy.ndarray):
            Arrays to be scattered.
        root (int): Rank of root process.

    Returns:
        ys (numpy.ndarray): Received array.
    """
    chainer.utils.experimental(
        'chainermn.communicators.CommunicatorBase.scatter')

    is_master = self.mpi_comm.rank == root

    if is_master:
        # Type check.
        msgtype = _MessageType(xs)

        if msgtype.is_tuple:
            if xs[0].dtype != numpy.float32:
                raise TypeError(
                    'scatter only supports dtype == numpy.float32')

            if len(msgtype.shapes) != self.size:
                raise ValueError('the length of xs must be consistent '
                                 'with communicator size')

            xp = chainer.cuda.get_array_module(*xs)
            msgtype = tuple([_MessageType(x) for x in xs])
            shapes = [mty.shapes[0] for mty in msgtype]
            # concatenate([x.reshape(-1) ... ], axis=0) will fail
            xs = xp.concatenate([x.reshape(1, -1) for x in xs], axis=1)

        else:
            assert len(msgtype.shapes) == 1

            if xs.dtype != numpy.float32:
                raise TypeError(
                    'scatter only supports dtype == numpy.float32')

            if msgtype.shapes[0][0] != self.mpi_comm.size:
                raise ValueError(
                    'scatter received inconsistent number of inputs '
                    'with communicator size')

            xp = chainer.cuda.get_array_module(xs)
            msgtype = tuple(
                [_MessageType(xs[0]) for _ in range(self.size)])
            shapes = [xs.shape[1:] for _ in range(self.size)]

        msgtype = self.mpi_comm.scatter(msgtype, root)
        shape = msgtype.shapes[0]

        # Collective communication.
        slens = [numpy.prod(s) for s in shapes]
        sbuf = _memory_utility.array_to_buffer_object(xs)[0]
        rbuf = numpy.empty(numpy.prod(shape), dtype=numpy.float32)
        if xp is not numpy:
            chainer.cuda.Stream.null.synchronize()

        self.mpi_comm.Scatterv(
            [sbuf, (slens, _cnt_to_dsp(slens)), mpi4py.MPI.FLOAT],
            rbuf, root)

        return rbuf.reshape(shape)

    else:  # slave processes
        msgtypes = self.mpi_comm.scatter(None, root)
        shape = msgtypes.shapes[0]
        rbuf = numpy.empty(numpy.prod(shape), dtype=numpy.float32)
        self.mpi_comm.Scatterv(None, rbuf, root)
        return rbuf.reshape(shape)