Example #1
    def multi_node_mean(self, array_a, array_b):
        # Despite using Allreduce under the hood, this actually computes a mean:
        # Sigma(a, all-procs)/n -> b, or
        # Sigma(b, all-procs)/n -> b if array_a is None
        if chainer.is_debug():
            self.check_ready_to_allreduce(array_a, array_b)

        is_float16 = array_b.dtype == numpy.float16
        if array_a is None:
            buffer_a = mpi4py.MPI.IN_PLACE
        elif is_float16:
            assert array_a.dtype == array_b.dtype
            buffer_a = _memory_utility.array_to_buffer_object(
                array_a.astype(numpy.float32))
        else:
            buffer_a = _memory_utility.array_to_buffer_object(array_a)

        if is_float16:
            array_b32 = array_b.astype(numpy.float32)
        else:
            array_b32 = array_b
        buffer_b = _memory_utility.array_to_buffer_object(array_b32)

        self.mpi_comm.Allreduce(buffer_a, buffer_b)

        if is_float16:
            xp = chainer.backend.get_array_module(array_b)
            xp.copyto(array_b, array_b32.astype(numpy.float16), casting='no')

        array_b *= 1.0 / self.mpi_comm.size

        if chainer.is_debug():
            self.ensure_all_finite(array_b)
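
The float16 handling above can be reproduced with plain mpi4py and numpy. The sketch below is illustrative only (the function name and the standalone use of MPI.IN_PLACE are assumptions, not ChainerMN API): it widens to float32 because MPI has no reliable half-precision reduction, reduces, then narrows back and averages.

import mpi4py.MPI
import numpy


def multi_node_mean_fp16(comm, array):
    # Widen to float32 for the reduction, since float16 MPI reductions
    # are poorly supported.
    wide = array.astype(numpy.float32)
    # In-place Allreduce: 'wide' ends up holding the element-wise sum
    # over all ranks (the default MPI op is SUM).
    comm.Allreduce(mpi4py.MPI.IN_PLACE, wide)
    # Narrow back to float16, then divide by the number of ranks.
    numpy.copyto(array, wide.astype(numpy.float16), casting='no')
    array *= 1.0 / comm.size


comm = mpi4py.MPI.COMM_WORLD
x = numpy.arange(4, dtype=numpy.float16)
multi_node_mean_fp16(comm, x)  # x now holds the element-wise mean over ranks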
Example #2
    def allreduce(self, x):
        """A primitive of inter-process allreduce communication.

        This method tries to invoke allreduce communication within the
        communicator. All processes in the communicator are expected to
        invoke ``allreduce()``. This method relies on mpi4py's fast
        communication optimized for numpy arrays, as do ``send()`` and
        ``recv()``.

        Note that this method requires the data to have the same shape on
        all processes, and that it cannot handle tuple data.

        If ``x`` is a numpy array, the received data will also be allocated
        as a numpy array. Additionally, when ``x`` is a cupy array, the
        returned array will be placed on the current device
        (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
        regardless of which device the argument is placed on at remote nodes.

        Args:
            x (numpy/cupy array): An array to apply the allreduce
                operation to.

        Returns:
            ys (numpy/cupy array): An array to which allreduce
                (currently SUM only) has been applied.

        """

        chainer.utils.experimental(
            'chainermn.communicators.CommunicatorBase.allreduce')

        msgtype = _MessageType(x)
        _check_dtype('allreduce', msgtype)

        if msgtype.is_tuple:
            raise TypeError('allreduce cannot handle tuple data')

        xp = chainer.cuda.get_array_module(x)

        # TODO(kuenishi): do we check all messages have same shape and dims?

        # Source buffer
        sbuf = _memory_utility.array_to_buffer_object(
            x, _get_mpi_type(msgtype))
        # Destination buffer: keep the array itself so it can be reshaped after
        # the reduction (the buffer object wrapping it cannot be reshaped).
        dst = xp.empty([numpy.prod(msgtype.shapes[0])], dtype=msgtype.dtype)
        dbuf = _memory_utility.array_to_buffer_object(
            dst, _get_mpi_type(msgtype))
        self.mpi_comm.Allreduce(sbuf, dbuf)

        return dst.reshape(msgtype.shapes[0])
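
A hedged usage sketch for ``allreduce()`` (run under mpiexec; 'naive' is one of ChainerMN's standard communicator names):

import numpy
import chainermn

comm = chainermn.create_communicator('naive')
x = numpy.ones((2, 3), dtype=numpy.float32)
y = comm.allreduce(x)
# y has the same shape as x; every element equals comm.size,
# because the reduction is currently SUM only.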
Example #3
    def recv(self, source, tag):
        """A primitive of inter-process receiver.

        This method tries to receive a numpy array from the target process.
        The target process is expected to invoke ``send()``.
        This method relies on mpi4py's fast communication optimized for
        numpy arrays; note that any information attached to
        chainer.Variable objects is discarded.

        If the corresponding ``send()`` is invoked with a cupy array,
        the returned array will be placed on the current device
        (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
        regardless of which device the argument is placed on at remote nodes.

        Args:
            source (int): Target process specifier.
            tag (int): Message ID (MPI feature).

        Returns:
            data (tuple of numpy/cupy array or numpy/cupy array):
                Received data. If ``send()`` is invoked with tuple data,
                the result is also a tuple. Otherwise, it is a plain
                numpy/cupy array.
        """

        chainer.utils.experimental(
            'chainermn.communicators.MpiCommunicatorBase.recv')

        msgtype = self.mpi_comm.recv(source=source, tag=tag)
        xp = msgtype.get_array_module()

        if msgtype.is_tuple:
            msg = []
            for shape in msgtype.shapes:
                buf = xp.empty(
                    [chainer.utils.size_of_shape(shape)], dtype=msgtype.dtype)
                rtype = _get_mpi_type(msgtype)
                self.mpi_comm.Recv(
                    _memory_utility.array_to_buffer_object(buf, rtype),
                    source=source, tag=tag)
                msg.append(buf.reshape(shape))
            return tuple(msg)

        else:
            assert len(msgtype.shapes) == 1
            shape = msgtype.shapes[0]
            buf = xp.empty(
                [chainer.utils.size_of_shape(shape)], dtype=msgtype.dtype)
            rtype = _get_mpi_type(msgtype)
            self.mpi_comm.Recv(
                _memory_utility.array_to_buffer_object(buf, rtype),
                source=source, tag=tag)
            return buf.reshape(shape)
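
A hedged sketch of the ``send()``/``recv()`` pairing this docstring describes (needs at least two MPI processes; the communicator setup is assumed):

import numpy
import chainermn

comm = chainermn.create_communicator('naive')
if comm.rank == 0:
    # Rank 0 sends a (2, 3) float32 array to rank 1 with tag 0.
    comm.send(numpy.arange(6, dtype=numpy.float32).reshape(2, 3), dest=1, tag=0)
elif comm.rank == 1:
    # Rank 1 receives it as a plain numpy array of the same shape.
    data = comm.recv(source=0, tag=0)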
Example #4
    def bcast(self, x, root=0):
        """A primitive of inter-process broadcast communication.

        This method tries to invoke broadcast communication within the
        communicator. All processes in the communicator are expected to
        invoke ``bcast()``. This method relies on mpi4py's fast
        communication optimized for numpy arrays, as do ``send()`` and
        ``recv()``.

        If ``bcast()`` is invoked with a cupy array in the root process,
        the returned array will be placed on the current device
        (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
        regardless of which device the argument is placed on at remote nodes.

        Args:
            x (numpy/cupy array): Array to be broadcast.
            root (int): Rank of root process.

        Returns:
            ys (numpy/cupy array): The broadcast array.
        """
        chainer.utils.experimental(
            'chainermn.communicators.MpiCommunicatorBase.bcast')

        is_master = self.mpi_comm.rank == root

        if is_master:
            msgtype = _MessageType(x)
            _check_dtype('bcast', msgtype)

            if msgtype.is_tuple:
                raise TypeError('Tuple data cannot be broadcasted')

            msgtype = self.mpi_comm.bcast(msgtype, root)
            shape = msgtype.shapes[0]
            buf = _memory_utility.array_to_buffer_object(
                x, _get_mpi_type(msgtype))
            self.mpi_comm.Bcast(buf, root)
            return x
        else:
            msgtype = self.mpi_comm.bcast(None, root)
            xp = msgtype.get_array_module()
            shape = msgtype.shapes[0]
            buf = xp.empty(
                [chainer.utils.size_of_shape(shape)], dtype=msgtype.dtype)
            buftype = _get_mpi_type(msgtype)
            self.mpi_comm.Bcast(
                _memory_utility.array_to_buffer_object(buf, buftype),
                root)
            return buf.reshape(shape)
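
A hedged usage sketch for ``bcast()``: only the root's argument matters; every other rank receives a copy of it.

import numpy
import chainermn

comm = chainermn.create_communicator('naive')
if comm.rank == 0:
    x = numpy.arange(10, dtype=numpy.float32)
else:
    x = None  # ignored on non-root ranks
y = comm.bcast(x, root=0)  # every rank now holds the root's array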
Example #5
    def allgather(self, x):
        chainer.utils.experimental(
            'chainermn.communicators.MpiCommunicatorBase.allgather')

        msgtype = _MessageType(x)
        _check_dtype('allgather', msgtype)

        msgtypes = self.mpi_comm.allgather(msgtype)
        _check_dtypes_are_same(msgtypes)

        # Type check.
        for msgtype in msgtypes:
            if msgtype.is_tuple:
                raise TypeError('allgather cannot handle tuple data')

            assert len(msgtype.shapes) == 1

        # Collective communication.
        xp = chainer.backend.get_array_module(x)
        shapes = [msgtype.shapes[0] for msgtype in msgtypes]
        sbuf = _memory_utility.array_to_buffer_object(
            x, _get_mpi_type(msgtype))
        rlens = [chainer.utils.size_of_shape(s) for s in shapes]
        rbuf = xp.empty([sum(rlens)], dtype=msgtype.dtype)
        if xp is not numpy:
            chainer.cuda.Stream.null.synchronize()
        self.mpi_comm.Allgatherv(
            sbuf,
            [_memory_utility.get_device_memory_pointer(rbuf),
             (rlens, _cnt_to_dsp(rlens)), _get_mpi_type(msgtype)])
        ys = [rbuf[i:i + l].reshape(s)
              for i, l, s in zip(_cnt_to_dsp(rlens), rlens, shapes)]

        return tuple(ys)
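
The ``_cnt_to_dsp`` helper used above converts per-rank element counts into the displacement offsets that ``Allgatherv``/``Scatterv`` expect. A minimal equivalent, shown only to illustrate the idea (this is not necessarily the library's actual implementation):

import numpy


def cnt_to_dsp(counts):
    # Each rank's block starts where the previous ranks' blocks end,
    # e.g. [3, 2, 4] -> [0, 3, 5].
    return [0] + numpy.cumsum(counts)[:-1].tolist()


assert cnt_to_dsp([3, 2, 4]) == [0, 3, 5]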
Example #6
 def allreduce_grad(self, model):
     for param in _memory_utility.extract_params_set_grad(model):
         grad = param.grad
         is_float16 = param.grad.dtype == np.float16
         if is_float16:
             grad = grad.astype(np.float32)
         buf = _memory_utility.array_to_buffer_object(grad)
         self.mpi_comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)
         if is_float16:
             param.grad = grad.astype(np.float16)
         param.grad /= self.size
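
In practice ``allreduce_grad()`` is not called by hand; the multi-node optimizer wrapper invokes it after each backward pass. A hedged setup sketch:

import chainer
import chainer.links as L
import chainermn

comm = chainermn.create_communicator('naive')
model = L.Linear(3, 2)
optimizer = chainermn.create_multi_node_optimizer(
    chainer.optimizers.SGD(), comm)
optimizer.setup(model)
# optimizer.update(...) now averages gradients across workers
# by calling the communicator's allreduce_grad().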
Example #7
 def bcast_data(self, model):
     for _, param in sorted(model.namedparams()):
         if param.data is not None:
             data = param.data
             is_float16 = param.data.dtype == numpy.float16
             if is_float16:
                 data = data.astype(numpy.float32)
             buf = _memory_utility.array_to_buffer_object(data)
             self.mpi_comm.Bcast(buf)
             if is_float16:
                 param.data = data.astype(numpy.float16)
Example #8
    def __call__(self, trainer=None):
        # We need to delay the mpi4py import. Note that the
        # _memory_utility module also imports mpi4py.
        from chainermn.communicators._memory_utility \
            import array_to_buffer_object
        import mpi4py.MPI

        for _, param in sorted(_namedpersistents(self.model)):
            if hasattr(param, 'dtype') and param.dtype == np.float32:
                buf = array_to_buffer_object(param)
                self.comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)
                param /= self.comm.size
            else:
                pass  # Integer persistent variables are ignored
Example #9
 def bcast_data(self, model):
     for _, param in sorted(model.namedparams()):
         buf = _memory_utility.array_to_buffer_object(param.data)
         self.mpi_comm.Bcast(buf)
Example #10
    def gather(self, x, root=0):
        """A primitive of inter-process gather communication.

        This method tries to invoke gather communication within the
        communicator. All processes in the communicator are expected to
        invoke ``gather()``. This method relies on mpi4py's fast
        communication optimized for numpy arrays, as do ``send()`` and
        ``recv()``.

        If ``x`` is a numpy array, the received data will also be allocated
        as a numpy array. Additionally, when ``x`` is a cupy array, the
        returned array will be placed on the current device
        (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
        regardless of which device the argument is placed on at remote nodes.

        Args:
            x (numpy/cupy array): Array to be gathered.
            root (int): Rank of root process.

        Returns:
            ys (tuple of numpy/cupy array):
                Received arrays. ``None`` for non-root processes.
        """
        chainer.utils.experimental(
            'chainermn.communicators.MpiCommunicatorBase.gather')

        is_master = self.mpi_comm.rank == root

        msgtype = _MessageType(x)
        _check_dtype('gather', msgtype)

        msgtypes = self.mpi_comm.gather(msgtype, root)

        if is_master:
            _check_dtypes_are_same(msgtypes)

            for msgtype in msgtypes:
                if msgtype.is_tuple:
                    raise TypeError('gather cannot handle tuple data')

                assert len(msgtype.shapes) == 1

            xp = chainer.backend.get_array_module(x)
            sbuf = _memory_utility.array_to_buffer_object(
                x, _get_mpi_type(msgtype))
            shapes = [mty.shapes[0] for mty in msgtypes]
            rlens = [chainer.utils.size_of_shape(s) for s in shapes]
            rbuf = xp.empty([sum(rlens)], dtype=msgtype.dtype)

            if xp is not numpy:
                chainer.cuda.Stream.null.synchronize()

            self.mpi_comm.Gatherv(
                sbuf,
                [_memory_utility.get_device_memory_pointer(rbuf),
                 (rlens, _cnt_to_dsp(rlens)), _get_mpi_type(msgtype)],
                root)

            ys = [rbuf[i:i + l].reshape(s)
                  for i, l, s in zip(_cnt_to_dsp(rlens), rlens, shapes)]
            return tuple(ys)

        else:
            sbuf = _memory_utility.array_to_buffer_object(
                x, _get_mpi_type(msgtype))
            self.mpi_comm.Gatherv(sbuf, None, root)
            return None
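
A hedged usage sketch for ``gather()``: every rank contributes one array and only the root receives the tuple.

import numpy
import chainermn

comm = chainermn.create_communicator('naive')
x = numpy.full((2,), comm.rank, dtype=numpy.float32)
ys = comm.gather(x, root=0)
if comm.rank == 0:
    assert len(ys) == comm.size  # ys[i] is the array sent by rank i
else:
    assert ys is None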
Example #11
def broadcast_naive(mpi_comm, model):
    for _, param in sorted(model.namedparams()):
        buf = _memory_utility.array_to_buffer_object(param.data)
        mpi_comm.Bcast(buf)
Example #12
 def allreduce_grad(self, model):
     for _, param in sorted(model.namedparams()):
         buf = _memory_utility.array_to_buffer_object(param.grad)
         self.mpi_comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)
         param.grad /= self.size
Example #13
 def allreduce_grad(self, model):
     for param in _memory_utility.extract_params(model):
         buf = _memory_utility.array_to_buffer_object(param.grad)
         self.mpi_comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)
         param.grad /= self.size
Example #14
 def allreduce_grad(self, model):
     for param in _memory_utility.extract_params_set_grad(model):
         buf = _memory_utility.array_to_buffer_object(param.grad)
         self.mpi_comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)
         param.grad /= self.size
Example #15
    def gather(self, x, root=0):
        """A primitive of inter-process gather communication.

        This method tries to invoke gather communication within the
        communicator. All processes in the communicator are expected to
        invoke ``gather()``. This method relies on mpi4py's fast
        communication optimized for numpy arrays, as do ``send()`` and
        ``recv()``.

        If ``x`` is a numpy array, the received data will also be allocated
        as a numpy array. Additionally, when ``x`` is a cupy array, the
        returned array will be placed on the current device
        (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
        regardless of which device the argument is placed on at remote nodes.

        Args:
            x (numpy/cupy array): Array to be gathered.
            root (int): Rank of root process.

        Returns:
            ys (tuple of numpy/cupy array):
                Received arrays. ``None`` for non-root processes.
        """
        chainer.utils.experimental(
            'chainermn.communicators.MpiCommunicatorBase.gather')

        is_master = self.mpi_comm.rank == root

        msgtype = _MessageType(x)
        _check_dtype('gather', msgtype)

        msgtypes = self.mpi_comm.gather(msgtype, root)

        if is_master:
            _check_dtypes_are_same(msgtypes)

            for msgtype in msgtypes:
                if msgtype.is_tuple:
                    raise TypeError('gather cannot handle tuple data')

                assert len(msgtype.shapes) == 1

            xp = chainer.backend.get_array_module(x)
            sbuf = _memory_utility.array_to_buffer_object(
                x, _get_mpi_type(msgtype))
            shapes = [mty.shapes[0] for mty in msgtypes]
            rlens = [numpy.prod(s, dtype=int) for s in shapes]
            rbuf = xp.empty([sum(rlens)], dtype=msgtype.dtype)

            if xp is not numpy:
                chainer.cuda.Stream.null.synchronize()

            self.mpi_comm.Gatherv(sbuf, [
                _memory_utility.get_device_memory_pointer(rbuf),
                (rlens, _cnt_to_dsp(rlens)),
                _get_mpi_type(msgtype)
            ], root)

            ys = [
                rbuf[i:i + l].reshape(s)
                for i, l, s in zip(_cnt_to_dsp(rlens), rlens, shapes)
            ]
            return tuple(ys)

        else:
            sbuf = _memory_utility.array_to_buffer_object(
                x, _get_mpi_type(msgtype))
            self.mpi_comm.Gatherv(sbuf, None, root)
            return None
Example #16
    def scatter(self, xs, root=0):
        """A primitive of inter-process scatter communication.

        This method tries to invoke scatter communication within the
        communicator. All processes in the communicator are expected to
        invoke ``scatter()``. This method relies on mpi4py's fast
        communication optimized for numpy arrays, as do ``send()`` and
        ``recv()``.

        If ``xs`` is a tuple, each element is sent to a different process.
        The length of the tuple must be the same as the communicator size.
        If ``xs`` is a ``numpy.ndarray``, it is split along the first
        axis and sent to different processes. For slave processes, ``xs``
        is allowed to be any value (it is ignored).

        If ``scatter()`` is invoked with a cupy array in the root process,
        the returned array will be placed on the current device
        (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
        regardless of which device the argument is placed on at remote nodes.

        Args:
            xs (tuple of numpy/cupy array): Arrays to be scattered.
            root (int): Rank of root process.

        Returns:
            ys (numpy/cupy array): The array received by this process.
        """
        chainer.utils.experimental(
            'chainermn.communicators.CommunicatorBase.scatter')

        is_master = self.mpi_comm.rank == root

        if is_master:
            # Type check.
            msgtype = _MessageType(xs)
            _check_dtype('scatter', msgtype)

            if msgtype.is_tuple:
                if len(msgtype.shapes) != self.size:
                    raise ValueError(
                        'the length of xs must be consistent '
                        'with communicator size')

                xp = chainer.backend.get_array_module(*xs)
                msgtype = tuple([_MessageType(x) for x in xs])
                shapes = [mty.shapes[0] for mty in msgtype]
                # concatenate([x.reshape(-1) ... ], axis=0) will fail
                xs = xp.concatenate([x.reshape(1, -1) for x in xs], axis=1)

            else:
                assert len(msgtype.shapes) == 1

                if msgtype.shapes[0][0] != self.mpi_comm.size:
                    raise ValueError(
                        'scatter received inconsistent number of inputs '
                        'with communicator size')

                xp = chainer.backend.get_array_module(xs)
                msgtype = tuple([_MessageType(xs[0])
                                 for _ in range(self.size)])
                shapes = [xs.shape[1:] for _ in range(self.size)]

            msgtype = self.mpi_comm.scatter(msgtype, root)
            shape = msgtype.shapes[0]

            # Collective communication.
            slens = [chainer.utils.size_of_shape(s) for s in shapes]
            sbuf = _memory_utility.get_device_memory_pointer(xs)
            rbuf = xp.empty(
                [chainer.utils.size_of_shape(shape)], dtype=msgtype.dtype)
            rtype = _get_mpi_type(msgtype)
            if xp is not numpy:
                chainer.cuda.Stream.null.synchronize()

            self.mpi_comm.Scatterv(
                [sbuf, (slens, _cnt_to_dsp(slens)), _get_mpi_type(msgtype)],
                _memory_utility.array_to_buffer_object(rbuf, rtype), root)

            return rbuf.reshape(shape)

        else:  # slave processes
            msgtypes = self.mpi_comm.scatter(None, root)
            xp = msgtypes.get_array_module()
            shape = msgtypes.shapes[0]
            rbuf = xp.empty(
                [chainer.utils.size_of_shape(shape)], dtype=msgtypes.dtype)
            rtype = _get_mpi_type(msgtypes)
            self.mpi_comm.Scatterv(
                None,
                _memory_utility.array_to_buffer_object(rbuf, rtype),
                root)
            return rbuf.reshape(shape)
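
A hedged usage sketch for ``scatter()``: the root splits an array along its first axis and each rank receives one slice; non-root ranks may pass anything.

import numpy
import chainermn

comm = chainermn.create_communicator('naive')
if comm.rank == 0:
    xs = numpy.arange(comm.size * 3, dtype=numpy.float32).reshape(comm.size, 3)
else:
    xs = None  # ignored on non-root ranks
y = comm.scatter(xs, root=0)  # a (3,) slice: row comm.rank of the root's array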
Example #17
    def recv(self, source, tag):
        """A primitive of inter-process receiver.

        This method tries to receive a numpy array from the target process.
        The target process is expected to invoke ``send()``.
        This method relies on mpi4py's fast communication optimized for
        numpy arrays; note that any information attached to
        chainer.Variable objects is discarded.

        If the corresponding ``send()`` is invoked with a cupy array,
        the returned array will be placed on the current device
        (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
        regardless of which device the argument is placed on at remote nodes.

        Args:
            source (int): Target process specifier.
            tag (int): Message ID (MPI feature).

        Returns:
            data (tuple of numpy/cupy array or numpy/cupy array):
                Received data. If ``send()`` is invoked with tuple data,
                the result is also a tuple. Otherwise, it is a plain
                numpy/cupy array.
        """

        chainer.utils.experimental(
            'chainermn.communicators.MpiCommunicatorBase.recv')

        msgtype = self.mpi_comm.recv(source=source, tag=tag)
        xp = msgtype.get_array_module()
        if numpy.float16 == msgtype.dtype:
            comm_dtype = numpy.float32
        else:
            comm_dtype = msgtype.dtype

        if msgtype.is_tuple:
            msg = []
            for shape in msgtype.shapes:
                buf = xp.empty(
                    [chainer.utils.size_of_shape(shape)], dtype=comm_dtype)
                rtype = _get_mpi_type(msgtype)
                self.mpi_comm.Recv(
                    _memory_utility.array_to_buffer_object(buf, rtype),
                    source=source, tag=tag)

                if numpy.float16 == msgtype.dtype:
                    buf = buf.astype(numpy.float16)
                msg.append(buf.reshape(shape))
            return tuple(msg)

        else:
            assert len(msgtype.shapes) == 1
            shape = msgtype.shapes[0]
            buf = xp.empty([chainer.utils.size_of_shape(shape)],
                           dtype=comm_dtype)
            rtype = _get_mpi_type(msgtype)
            self.mpi_comm.Recv(
                _memory_utility.array_to_buffer_object(buf, rtype),
                source=source, tag=tag)

            if numpy.float16 == msgtype.dtype:
                buf = buf.astype(numpy.float16)
            return buf.reshape(shape)
Example #18
 def bcast_data(self, model):
     for _, param in sorted(model.namedparams()):
         if param.data is not None:
             buf = _memory_utility.array_to_buffer_object(param.data)
             self.mpi_comm.Bcast(buf)
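
``bcast_data()`` is typically called once, right after model creation, so that every worker starts from the root's initial parameters. A hedged sketch:

import chainer.links as L
import chainermn

comm = chainermn.create_communicator('naive')
model = L.Linear(3, 2)  # randomly initialised on every rank
comm.bcast_data(model)  # now all ranks hold rank 0's W and b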
Example #19
    def alltoall(self, xs):
        """A primitive of inter-process all-to-all function.

        This method tries to invoke all-to-all communication within the
        communicator. All processes in the communicator are expected to
        invoke ``alltoall()``. This method relies on mpi4py's fast
        communication optimized for numpy arrays, as do ``send()`` and
        ``recv()``.

        Args:
            xs (tuple of numpy.ndarray): Arrays to send, one per process.
                The length of the tuple must equal the communicator size.

        Returns:
            ys (tuple of numpy.ndarray):
                Received arrays. The length of the tuple equals
                the communicator size.
        """
        chainer.utils.experimental(
            'chainermn.communicators.MpiCommunicatorBase.alltoall')

        if len(xs) != self.size:
            raise ValueError(
                'The length of data must be same as communicator size.')

        # Type check.
        for x in xs:
            if x.dtype != numpy.float32:
                raise ValueError(
                    'alltoall only support dtype == numpy.float32')

        # Mediate #axes of arrays.
        sndims = numpy.array([x.ndim for x in xs], dtype=numpy.int32)
        rndims = numpy.empty(self.size, dtype=numpy.int32)
        self.mpi_comm.Alltoall(
            [sndims, mpi4py.MPI.INT],
            [rndims, mpi4py.MPI.INT])

        # Arbitrate shapes of arrays.
        sshapes = numpy.hstack([x.shape for x in xs]).astype(numpy.int32)
        rshapes = numpy.empty(sum(rndims), dtype=numpy.int32)
        self.mpi_comm.Alltoallv(
            [sshapes, (sndims, _cnt_to_dsp(sndims)), mpi4py.MPI.INT],
            [rshapes, (rndims, _cnt_to_dsp(rndims)), mpi4py.MPI.INT])
        shapes = [rshapes[i:i + l]
                  for i, l in zip(_cnt_to_dsp(rndims), rndims)]

        # Collective communication.
        slens = [numpy.prod(x.shape) for x in xs]
        xp = chainer.cuda.get_array_module(xs[0])
        sbuf = xp.hstack([x.reshape(-1) for x in xs])
        rlens = [numpy.prod(s) for s in shapes]
        rbuf = numpy.empty(sum(rlens), dtype=numpy.float32)
        if xp is not numpy:
            sbuf = _memory_utility.array_to_buffer_object(sbuf)[0]
            chainer.cuda.Stream.null.synchronize()
        self.mpi_comm.Alltoallv(
            [sbuf, (slens, _cnt_to_dsp(slens)), mpi4py.MPI.FLOAT],
            [rbuf, (rlens, _cnt_to_dsp(rlens)), mpi4py.MPI.FLOAT])
        ys = [rbuf[i:i + l].reshape(s)
              for i, l, s in zip(_cnt_to_dsp(rlens), rlens, shapes)]

        return tuple(ys)
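
A hedged usage sketch for ``alltoall()``: rank r sends ``xs[k]`` to rank k and receives one array from every rank; this implementation accepts float32 only.

import numpy
import chainermn

comm = chainermn.create_communicator('naive')
xs = tuple(numpy.full((2,), comm.rank, dtype=numpy.float32)
           for _ in range(comm.size))
ys = comm.alltoall(xs)  # ys[k] was sent by rank k, so it is filled with k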
Example #20
    def scatter(self, xs, root=0):
        """A primitive of inter-process scatter communication.

        This method tries to invoke scatter communication within the
        communicator. All processes in the communicator are expected to
        invoke ``scatter()``. This method relies on mpi4py's fast
        communication optimized for numpy arrays, as do ``send()`` and
        ``recv()``.

        If ``xs`` is a tuple, each element is sent to a different process.
        The length of the tuple must be the same as the communicator size.
        If ``xs`` is a ``numpy.ndarray``, it is split along the first
        axis and sent to different processes. For slave processes, ``xs``
        is allowed to be any value (it is ignored).

        If ``scatter()`` is invoked with a cupy array in the root process,
        the returned array will be placed on the current device
        (``https://docs-cupy.chainer.org/en/stable/tutorial/basic.html#current-device``)
        regardless of which device the argument is placed on at remote nodes.

        Args:
            xs (tuple of numpy/cupy array): Arrays to be scattered.
            root (int): Rank of root process.

        Returns:
            ys (numpy/cupy array): The array received by this process.
        """
        chainer.utils.experimental(
            'chainermn.communicators.CommunicatorBase.scatter')

        is_master = self.mpi_comm.rank == root

        if is_master:
            # Type check.
            msgtype = _MessageType(xs)
            _check_dtype('scatter', msgtype)

            if msgtype.is_tuple:
                if len(msgtype.shapes) != self.size:
                    raise ValueError('the length of xs must be consistent '
                                     'with communicator size')

                xp = chainer.backend.get_array_module(*xs)
                msgtype = tuple([_MessageType(x) for x in xs])
                shapes = [mty.shapes[0] for mty in msgtype]
                # concatenate([x.reshape(-1) ... ], axis=0) will fail
                xs = xp.concatenate([x.reshape(1, -1) for x in xs], axis=1)

            else:
                assert len(msgtype.shapes) == 1

                if msgtype.shapes[0][0] != self.mpi_comm.size:
                    raise ValueError(
                        'scatter received inconsistent number of inputs '
                        'with communicator size')

                xp = chainer.backend.get_array_module(xs)
                msgtype = tuple(
                    [_MessageType(xs[0]) for _ in range(self.size)])
                shapes = [xs.shape[1:] for _ in range(self.size)]

            msgtype = self.mpi_comm.scatter(msgtype, root)
            shape = msgtype.shapes[0]

            # Collective communication.
            slens = [numpy.prod(s, dtype=int) for s in shapes]
            sbuf = _memory_utility.get_device_memory_pointer(xs)
            rbuf = xp.empty([numpy.prod(shape, dtype=int)],
                            dtype=msgtype.dtype)
            rtype = _get_mpi_type(msgtype)
            if xp is not numpy:
                chainer.cuda.Stream.null.synchronize()

            self.mpi_comm.Scatterv(
                [sbuf, (slens, _cnt_to_dsp(slens)),
                 _get_mpi_type(msgtype)],
                _memory_utility.array_to_buffer_object(rbuf, rtype), root)

            return rbuf.reshape(shape)

        else:  # slave processes
            msgtypes = self.mpi_comm.scatter(None, root)
            xp = msgtypes.get_array_module()
            shape = msgtypes.shapes[0]
            rbuf = xp.empty([numpy.prod(shape, dtype=int)],
                            dtype=msgtypes.dtype)
            rtype = _get_mpi_type(msgtypes)
            self.mpi_comm.Scatterv(
                None, _memory_utility.array_to_buffer_object(rbuf, rtype),
                root)
            return rbuf.reshape(shape)
Example #21
    def scatter(self, xs, root=0):
        """A primitive of inter-process scatter communication.

        This method tries to invoke scatter communication within the
        communicator. All processes in the communicator are expected to
        invoke ``scatter()``. This method relies on mpi4py's fast
        communication optimized for numpy arrays, as do ``send()`` and
        ``recv()``.

        If ``xs`` is a tuple, each element is sent to a different process.
        The length of the tuple must be the same as the communicator size.
        If ``xs`` is a ``numpy.ndarray``, it is split along the first
        axis and sent to different processes. For slave processes, ``xs``
        is allowed to be any value (it is ignored).

        Args:
            xs (tuple of numpy.ndarray or numpy.ndarray): Arrays to be
                scattered.
            root (int): Rank of root process.

        Returns:
            ys (numpy.ndarray): The array received by this process.
        """
        chainer.utils.experimental(
            'chainermn.communicators.CommunicatorBase.scatter')

        is_master = self.mpi_comm.rank == root

        if is_master:
            # Type check.
            msgtype = _MessageType(xs)

            if msgtype.is_tuple:
                if xs[0].dtype != numpy.float32:
                    raise TypeError(
                        'scatter only support dtype == numpy.float32')

                if len(msgtype.shapes) != self.size:
                    raise ValueError('the length of xs must be consistent '
                                     'with communicator size')

                xp = chainer.cuda.get_array_module(*xs)
                msgtype = tuple([_MessageType(x) for x in xs])
                shapes = [mty.shapes[0] for mty in msgtype]
                # concatenate([x.reshape(-1) ... ], axis=0) will fail
                xs = xp.concatenate([x.reshape(1, -1) for x in xs], axis=1)

            else:
                assert len(msgtype.shapes) == 1

                if xs.dtype != numpy.float32:
                    raise TypeError(
                        'scatter only support dtype == numpy.float32')

                if msgtype.shapes[0][0] != self.mpi_comm.size:
                    raise ValueError(
                        'scatter received inconsistent number of inputs '
                        'with communicator size')

                xp = chainer.cuda.get_array_module(xs)
                msgtype = tuple(
                    [_MessageType(xs[0]) for _ in range(self.size)])
                shapes = [xs.shape[1:] for _ in range(self.size)]

            msgtype = self.mpi_comm.scatter(msgtype, root)
            shape = msgtype.shapes[0]

            # Collective communication.
            slens = [numpy.prod(s) for s in shapes]
            sbuf = _memory_utility.array_to_buffer_object(xs)[0]
            rbuf = numpy.empty(numpy.prod(shape), dtype=numpy.float32)
            if xp is not numpy:
                chainer.cuda.Stream.null.synchronize()

            self.mpi_comm.Scatterv(
                [sbuf,
                 (slens, _cnt_to_dsp(slens)), mpi4py.MPI.FLOAT], rbuf, root)

            return rbuf.reshape(shape)

        else:  # slave processes
            msgtypes = self.mpi_comm.scatter(None, root)
            shape = msgtypes.shapes[0]
            rbuf = numpy.empty(numpy.prod(shape), dtype=numpy.float32)
            self.mpi_comm.Scatterv(None, rbuf, root)
            return rbuf.reshape(shape)