예제 #1
0
def sum_inplace_scalar(a):
    """
    Sum a scalar (or array-like) value elementwise over all MPI processes.

    The value is converted with ``numpy.asarray``. When ``_n_nodes > 1`` it is
    reduced with an in-place ``Allreduce`` using ``SUM``, so every process ends
    up holding the same total.

    Args:
        a: A scalar or array-like. If it is already a NumPy array, ``asarray``
           returns it unchanged and it is overwritten in place.

    Returns:
        numpy.ndarray: The summed value (a 0-d array for scalar input).
    """
    ar = _np.asarray(a)

    # NOTE(review): _n_nodes presumably holds the number of MPI processes.
    if _n_nodes > 1:
        # ``reshape(-1)`` silently returns a *copy* for arrays that are not
        # C-contiguous, and reducing into that copy would throw the result
        # away. Reduce into a guaranteed-contiguous buffer instead, and copy
        # the result back if a separate buffer had to be made.
        buf = ar if ar.flags.c_contiguous else _np.ascontiguousarray(ar)
        MPI_py_comm.Allreduce(_MPI.IN_PLACE, buf.reshape(-1), op=_MPI.SUM)
        if buf is not ar:
            ar[...] = buf

    return ar
예제 #2
0
def sum_inplace_MPI(a):
    """
    Computes the elementwise sum of a numpy array over all MPI processes.

    When ``_n_nodes > 1`` the array is reduced with an in-place ``Allreduce``
    using ``SUM``; otherwise it is returned unchanged.

    Args:
        a (numpy.ndarray): The input array, which will be overwritten in place.

    Returns:
        numpy.ndarray: ``a`` itself, now holding the sum over all processes.
    """
    # NOTE(review): _n_nodes presumably holds the number of MPI processes.
    if _n_nodes > 1:
        # ``a.reshape(-1)`` is only a view when ``a`` is C-contiguous. Otherwise
        # it is a copy, and the reduction would never reach ``a``. Reduce into
        # a contiguous buffer and write it back so ``a`` really is overwritten.
        buf = a if a.flags.c_contiguous else _np.ascontiguousarray(a)
        MPI_py_comm.Allreduce(_MPI.IN_PLACE, buf.reshape(-1), op=_MPI.SUM)
        if buf is not a:
            a[...] = buf

    return a
예제 #3
0
파일: mpi_stats.py 프로젝트: vlpap/netket
def total_size(a, axis=None):
    """
    Return the number of elements of ``a`` summed over every MPI process.

    This is the MPI sum across processes of ``a.size``, or of
    ``a.shape[axis]`` when an axis is given.

    Args:
        a: The input array.
        axis: Optional axis. When given, only the length of that axis is
              counted instead of the total number of elements.

    Returns:
        The local count when ``_n_nodes <= 1``; otherwise that count summed
        over all processes via ``allreduce``.
    """
    local_count = a.size if axis is None else a.shape[axis]

    # Single-process run: nothing to reduce.
    if _n_nodes <= 1:
        return local_count

    return MPI_comm.allreduce(local_count, op=MPI.SUM)
예제 #4
0
파일: mpi_stats.py 프로젝트: vlpap/netket
def sum(a, axis=None, out=None, keepdims: bool = False):
    """
    Compute the sum along the specified axis and over MPI processes.

    The local sum is computed with ``numpy.sum``. When ``_n_nodes > 1`` it is
    then reduced across processes with an in-place ``Allreduce`` using ``SUM``.

    Args:
        a: The input array
        axis: Axis or axes along which the sum is computed. The default (None) is to
              compute the sum of the flattened array.
        out: An optional pre-allocated array to fill with the result.
        keepdims: If True the output array will have the same number of dimensions as the input,
              with the reduced axes having length 1. (default=False)

    Returns:
        The array with reduced dimensions defined by axis. If out is not none, returns out.

    """
    # asarray is necessary for the axis=None case to work, as the MPI call requires a NumPy array
    out = _np.asarray(_np.sum(a, axis=axis, out=out, keepdims=keepdims))

    if _n_nodes > 1:
        # A caller-supplied ``out`` may not be C-contiguous. In that case
        # ``reshape(-1)`` would return a copy and the reduced values would be
        # lost. Reduce into a contiguous buffer and copy back if one was made.
        buf = out if out.flags.c_contiguous else _np.ascontiguousarray(out)
        MPI_comm.Allreduce(MPI.IN_PLACE, buf.reshape(-1), op=MPI.SUM)
        if buf is not out:
            out[...] = buf

    return out