Esempio n. 1
0
 def func(nglobal):
     """Check a rank sum computed over a communicator filtered to ranks < nglobal.

     Every rank starts from its own rank number. The ranks satisfying the
     condition add those numbers together with an in-place Allreduce. Rank 0's
     value (the default root of ``bcast``) is then broadcast to everybody. It
     is compared with 0 + 1 + ... + (n - 1), where n is the number of
     selected ranks.
     """
     value = np.array(comm.rank)
     selected = comm.rank < nglobal
     with filter_comm(selected, comm) as subcomm:
         # presumably None on the ranks excluded by the condition
         if subcomm is not None:
             subcomm.Allreduce(MPI.IN_PLACE, as_mpi(value))
     total = comm.bcast(value)
     nselected = min(comm.size, nglobal)
     expected = nselected * (nselected - 1) // 2
     assert total == expected
Esempio n. 2
0
 def func(nglobal):
     """Verify summing rank numbers over the sub-communicator of ranks < nglobal.

     The participating ranks reduce their rank numbers in place. The result
     held by rank 0 is broadcast to all ranks of ``comm``. It must equal the
     sum of the first n integers starting at 0, with
     n = min(comm.size, nglobal).
     """
     value = np.array(comm.rank)
     selected = comm.rank < nglobal
     with filter_comm(selected, comm) as subcomm:
         # presumably None on the ranks that do not take part
         if subcomm is not None:
             subcomm.Allreduce(MPI.IN_PLACE, as_mpi(value))
     total = comm.bcast(value)
     nselected = min(comm.size, nglobal)
     expected = nselected * (nselected - 1) // 2
     assert total == expected
Esempio n. 3
0
def diffTdiff(input, output, axis=0, scalar=1., comm=None):
    """
    Inplace discrete difference transpose times discrete difference.

    Apply D^T D along ``axis`` to ``input`` and store the result in
    ``output``, where D is the discrete difference operator. ``input`` and
    ``output`` may share the same buffer; the kernel is told so through its
    ``inplace`` flag.

    Parameters
    ----------
    input : ndarray
        Input array, of dtype ``var.FLOAT_DTYPE``.
    output : ndarray
        C-contiguous array receiving the result. The kernels write through
        ``output.ravel()``, so a non-contiguous array would only receive
        the result in a discarded copy.
    axis : int, optional
        Non-negative axis along which the difference is taken.
    scalar : float, optional
        Factor forwarded to the Fortran kernel as a ``var.FLOAT_DTYPE``
        value (presumably scales the result -- TODO confirm).
    comm : MPI communicator, optional
        Defaults to ``MPI.COMM_WORLD``. When ``axis`` is 0 and ``comm`` has
        more than one process, the MPI kernel is used. Presumably this is
        because the first axis is distributed across processes (confirm).

    Raises
    ------
    TypeError
        If ``input`` is not an ndarray or its dtype is not
        ``var.FLOAT_DTYPE``.
    ValueError
        If ``axis`` is negative or not less than ``input.ndim``, or if
        ``output`` is not C-contiguous.
    RuntimeError
        If the MPI kernel returns a non-zero status.
    """

    if not isinstance(input, np.ndarray):
        raise TypeError('Input array is not an ndarray.')

    if input.dtype != var.FLOAT_DTYPE:
        raise TypeError('The data type of the input array is not ' +
                        str(var.FLOAT_DTYPE.type) + '.')

    if axis < 0:
        raise ValueError("Invalid negative axis '" + str(axis) + "'.")

    if comm is None:
        comm = MPI.COMM_WORLD

    scalar = np.asarray(scalar, var.FLOAT_DTYPE)
    ndim = input.ndim
    if ndim == 0:
        # A 0-d input has no axis to difference along: the result is zero.
        output.flat = 0
        return

    if axis >= ndim:
        raise ValueError("Invalid axis '" + str(axis) + "'. Expected value is"
                         ' less than ' + str(ndim) + '.')

    # The kernels write into output.ravel(). For a non-contiguous array,
    # ravel() returns a copy, and the result would be silently lost.
    if not output.flags.c_contiguous:
        raise ValueError('The output array is not C-contiguous.')

    # Tell the kernel whether input and output start at the same address.
    inplace = (input.__array_interface__['data'][0] ==
               output.__array_interface__['data'][0])

    # The kernels receive the reversed shape (input.T.shape). ndim - axis is
    # the 1-based position of `axis` in that reversed shape (presumably the
    # Fortran-order axis index).
    if axis != 0 or comm.size == 1:
        if input.size == 0:
            return
        tmf.operators.difftdiff(input.ravel(), output.ravel(), ndim - axis,
                                np.asarray(input.T.shape), scalar, inplace)
        return

    if product(input.shape[1:]) == 0:
        return
    # Only processes holding a non-empty slice along the first axis take
    # part; filter_comm presumably yields None on the others.
    with filter_comm(input.shape[0] > 0, comm) as fcomm:
        if fcomm is not None:
            status = tmf.operators_mpi.difftdiff(
                input.ravel(), output.ravel(), ndim - axis,
                np.asarray(input.T.shape), scalar, inplace, fcomm.py2f())
            if status != 0:
                raise RuntimeError(
                    'tmf.operators_mpi.difftdiff failed with status ' +
                    str(status) + '.')