Example #1
def sum_inplace_jax(x):
    # if not isinstance(x, jax.interpreters.xla.DeviceArray):
    #    raise TypeError("Argument to sum_inplace_jax must be a DeviceArray, got {}"
    #            .format(type(x)))

    if _n_nodes == 1:
        return x

    # The pointer trick below only works on CPUs;
    # we should make this work for GPUs too.
    # TODO: unsafe_buffer_pointer is not yet considered a definitive interface
    ptr = x.block_until_ready().device_buffer.unsafe_buffer_pointer()

    # The above is faster.
    # The alternative below should work more often, but might copy.
    # Depending on future changes in jaxlib, we might have to switch to it.
    # See google/jax #2123 and #1009
    # _x = jax.xla._force(x.block_until_ready())
    # ptr = _x.device_buffer.unsafe_buffer_pointer()

    # using native numpy because jax's numpy does not have ctypeslib
    data_pointer = _np.ctypeslib.ndpointer(x.dtype, shape=x.shape)

    # wrap the jax buffer into a standard numpy array, which MPI can handle in place
    arr = data_pointer(ptr).contents
    _MPI_comm.Allreduce(_MPI.IN_PLACE, arr.reshape(-1), op=_MPI.SUM)

    return x
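
A minimal usage sketch for the function above, run under MPI (e.g. mpirun -np 2 python script.py). The names _np, _MPI, _MPI_comm and _n_nodes are stand-ins for the module-level globals the snippet expects, and whether device_buffer.unsafe_buffer_pointer() is still available depends on the jaxlib version, as the snippet's own comments warn.

# Hypothetical setup reproducing the globals the snippet relies on.
import numpy as _np
import jax.numpy as jnp
from mpi4py import MPI as _MPI

_MPI_comm = _MPI.COMM_WORLD
_n_nodes = _MPI_comm.Get_size()

# Each rank contributes its own value; after the call every rank holds
# the element-wise sum across all ranks.
x = jnp.full((3,), _MPI_comm.Get_rank(), dtype=jnp.float32)
x = sum_inplace_jax(x)
print(_MPI_comm.Get_rank(), x)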
Example #2
def sum_inplace_scalar(a):
    """
    Computes the sum of a scalar over all MPI processes and returns it as a 0-d numpy array.
    """
    ar = _np.asarray(a)

    if _n_nodes > 1:
        _MPI_comm.Allreduce(_MPI.IN_PLACE, ar.reshape(-1), op=_MPI.SUM)

    return ar
Example #3
def total_size(a, axis=None):
    if axis is None:
        l_size = a.size
    else:
        l_size = a.shape[axis]

    if _n_nodes > 1:
        l_size = MPI_comm.allreduce(l_size, op=MPI.SUM)

    return l_size
Example #4
def sum_inplace_MPI(a):
    """
    Computes the elementwise sum of a numpy array over all MPI processes.

    Args:
        a (numpy.ndarray): The input array, which will be overwritten in place.
    """
    if _n_nodes > 1:
        _MPI_comm.Allreduce(_MPI.IN_PLACE, a.reshape(-1), op=_MPI.SUM)

    return a
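
A minimal usage sketch, with _np, _MPI, _MPI_comm and _n_nodes standing in for the snippet's module-level globals. Because the reduction happens in place through a.reshape(-1), the sketch assumes a contiguous input array so that the reshape is a view rather than a copy.

# Hypothetical setup mirroring the globals the snippet relies on.
import numpy as _np
from mpi4py import MPI as _MPI

_MPI_comm = _MPI.COMM_WORLD
_n_nodes = _MPI_comm.Get_size()

# Each rank fills the array with its own rank; afterwards every rank
# holds the element-wise sum 0 + 1 + ... + (n_nodes - 1).
a = _np.full((2, 2), float(_MPI_comm.Get_rank()))
sum_inplace_MPI(a)
print(_MPI_comm.Get_rank(), a)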
Example #5
def sum(a, axis=None, out=None):
    """
    Compute the sum along the specified axis and over MPI processes.
    """
    # asarray is necessary for the axis=None case to work, as the MPI call requires a NumPy array
    out = _np.asarray(_np.sum(a, axis=axis, out=out))

    if _n_nodes > 1:
        MPI_comm.Allreduce(MPI.IN_PLACE, out.reshape(-1), op=MPI.SUM)

    return out
Example #6
def seed(seed=None):
    """ Seed the random number generator. Each MPI process is automatically assigned
        a different, process-dependent sub-seed.

        Parameters:
                  seed (int, optional): Seed for the random number generator.

    """
    with objmode(derived_seed="int64"):
        size = _n_nodes
        rank = _rank

        if rank == 0:
            _np.random.seed(seed)
            derived_seed = _np.random.randint(0, 1 << 32, size=size)
        else:
            derived_seed = None

        if _n_nodes > 1:
            derived_seed = _MPI_comm.scatter(derived_seed, root=0)

    _np.random.seed(derived_seed)
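
A minimal usage sketch; objmode comes from numba, and _np, _MPI_comm, _rank and _n_nodes stand in for the snippet's module-level globals. Outside of a numba-compiled function the objmode context is expected to act as a pass-through, so calling seed from plain Python should still scatter a distinct sub-seed to each rank.

# Hypothetical setup reproducing the globals the snippet relies on.
import numpy as _np
from mpi4py import MPI as _MPI
from numba import objmode

_MPI_comm = _MPI.COMM_WORLD
_rank = _MPI_comm.Get_rank()
_n_nodes = _MPI_comm.Get_size()

seed(1234)
# Same seed argument on every rank, but each rank draws different numbers.
print(_rank, _np.random.rand())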
Example #7
def total_size(a, axis=None):
    """
    Compute the total number of elements stored in the input array among all MPI processes.

    This function essentially returns MPI_sum_among_processes(a.size).

    Args:
        a: The input array.
        axis: If specified, only considers the total size of that axis.

    Returns:
        a.size or a.shape[axis], reduced among all MPI processes.
    """
    if axis is None:
        l_size = a.size
    else:
        l_size = a.shape[axis]

    if _n_nodes > 1:
        l_size = MPI_comm.allreduce(l_size, op=MPI.SUM)

    return l_size
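
A minimal usage sketch; MPI, MPI_comm and _n_nodes stand in for the snippet's module-level globals (this snippet uses the names without a leading underscore). Each rank owns a different number of local rows, and total_size reports the global count.

# Hypothetical setup mirroring the globals the snippet relies on.
import numpy as _np
from mpi4py import MPI

MPI_comm = MPI.COMM_WORLD
_n_nodes = MPI_comm.Get_size()

# A different number of local rows on each rank.
a = _np.zeros((MPI_comm.Get_rank() + 1, 3))
print(total_size(a))           # all elements, summed over all ranks
print(total_size(a, axis=0))   # rows only, summed over all ranks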
Example #8
def sum(a, axis=None, out=None, keepdims: bool = False):
    """
    Compute the arithmetic mean along the specified axis and over MPI processes.

    Args:
        a: The input array
        axis: Axis or axes along which the mean is computed. The default (None) is to
              compute the mean of the flattened array.
        out: An optional pre-allocated array to fill with the result.
        keepdims: If True the output array will have the same number of dimensions as the input,
              with the reduced axes having length 1. (default=False)

    Returns:
        The array with reduced dimensions defined by axis. If out is not none, returns out.

    """
    # asarray is necessary for the axis=None case to work, as the MPI call requires a NumPy array
    out = _np.asarray(_np.sum(a, axis=axis, out=out, keepdims=keepdims))

    if _n_nodes > 1:
        MPI_comm.Allreduce(MPI.IN_PLACE, out.reshape(-1), op=MPI.SUM)

    return out
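
A minimal usage sketch; MPI, MPI_comm, _np and _n_nodes stand in for the snippet's module-level globals. Since the function is named sum, it shadows Python's built-in sum in its module, so the sketch calls it directly.

# Hypothetical setup mirroring the globals the snippet relies on.
import numpy as _np
from mpi4py import MPI

MPI_comm = MPI.COMM_WORLD
_n_nodes = MPI_comm.Get_size()

a = _np.arange(6.0).reshape(2, 3)
s = sum(a, axis=0, keepdims=True)   # shape (1, 3), reduced over rows and over ranks
print(MPI_comm.Get_rank(), s)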