def sum_inplace_jax(x):
    # if not isinstance(x, jax.interpreters.xla.DeviceArray):
    #     raise TypeError("Argument to sum_inplace_jax must be a DeviceArray, got {}"
    #                     .format(type(x)))

    if _n_nodes == 1:
        return x

    # This below only works on cpus...
    # we should make this work for gpus too..
    # TODO: unsafe_buffer_pointer is considered not yet definitive interface
    ptr = x.block_until_ready().device_buffer.unsafe_buffer_pointer()

    # The above is faster.
    # This below should work more often, but might copy.
    # Depending on future changes in jaxlib, we might have to switch to
    # this below.
    # see Google/jax #2123 and #1009
    # _x = jax.xla._force(x.block_until_ready())
    # ptr = _x.device_buffer.unsafe_buffer_pointer()

    # using native numpy because jax's numpy does not have ctypeslib
    data_pointer = _np.ctypeslib.ndpointer(x.dtype, shape=x.shape)

    # wrap jax data into a standard numpy array which is handled by MPI
    arr = data_pointer(ptr).contents

    _MPI_comm.Allreduce(_MPI.IN_PLACE, arr.reshape(-1), op=_MPI.SUM)

    return x

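# Usage sketch (not part of the original module; assumes _MPI_comm/_MPI come from
# mpi4py, jax is installed, and the script is launched with e.g.
# `mpirun -np 2 python script.py`). Each rank passes its local jax array and,
# after the in-place Allreduce on the underlying buffer, sees the global sum:
#
#   import jax.numpy as jnp
#   local = jnp.arange(4.0) * _rank       # different data on each rank
#   summed = sum_inplace_jax(local)       # every rank now holds the elementwise
#                                         # sum over all ranks
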
def sum_inplace_scalar(a):
    ar = _np.asarray(a)

    if _n_nodes > 1:
        _MPI_comm.Allreduce(_MPI.IN_PLACE, ar.reshape(-1), op=_MPI.SUM)

    return ar

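# Usage sketch (hypothetical, assuming an MPI launch with several ranks): a plain
# Python scalar is wrapped into a 0-d numpy array and summed over all ranks.
#
#   local_energy = 1.5 * (_rank + 1)            # a different float on each rank
#   total = sum_inplace_scalar(local_energy)    # 0-d array with the global sum
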
def total_size(a, axis=None):
    if axis is None:
        l_size = a.size
    else:
        l_size = a.shape[axis]

    if _n_nodes > 1:
        l_size = _MPI_comm.allreduce(l_size, op=_MPI.SUM)

    return l_size

def sum_inplace_MPI(a):
    """
    Computes the elementwise sum of a numpy array over all MPI processes.

    Args:
        a (numpy.ndarray): The input array, which will be overwritten in place.
    """
    if _n_nodes > 1:
        _MPI_comm.Allreduce(_MPI.IN_PLACE, a.reshape(-1), op=_MPI.SUM)

    return a

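# Usage sketch (hypothetical): the numpy buffer is overwritten in place with the
# elementwise sum over all ranks, so every rank ends up with identical content.
#
#   buf = _np.full((2, 2), float(_rank))   # different content on each rank
#   sum_inplace_MPI(buf)                   # buf now holds the sum over all ranks
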
def sum(a, axis=None, out=None):
    """
    Compute the sum along the specified axis and over MPI processes.
    """
    # asarray is necessary for the axis=None case to work, as the MPI call
    # requires a NumPy array
    out = _np.asarray(_np.sum(a, axis=axis, out=out))

    if _n_nodes > 1:
        _MPI_comm.Allreduce(_MPI.IN_PLACE, out.reshape(-1), op=_MPI.SUM)

    return out

def seed(seed=None):
    """
    Seed the random number generator. Each MPI process is automatically
    assigned a different, process-dependent, sub-seed.

    Args:
        seed (int, optional): Seed for the random number generator.
    """
    with objmode(derived_seed="int64"):
        size = _n_nodes
        rank = _rank

        if rank == 0:
            _np.random.seed(seed)
            derived_seed = _np.random.randint(0, 1 << 32, size=size)
        else:
            derived_seed = None

        if _n_nodes > 1:
            derived_seed = _MPI_comm.scatter(derived_seed, root=0)

        _np.random.seed(derived_seed)

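# Usage sketch (hypothetical): passing the same master seed on every rank gives
# each rank its own scattered sub-seed, so the per-rank random streams differ
# while the run as a whole stays reproducible.
#
#   seed(1234)                        # call with the same seed on all ranks
#   x = _np.random.uniform(size=3)    # different on each rank, reproducible
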
def total_size(a, axis=None):
    """
    Compute the total number of elements stored in the input array
    among all MPI processes.

    This function essentially returns MPI_sum_among_processes(a.size).

    Args:
        a: The input array.
        axis: If specified, only considers the total size of that axis.

    Returns:
        a.size or a.shape[axis], reduced among all MPI processes.
    """
    if axis is None:
        l_size = a.size
    else:
        l_size = a.shape[axis]

    if _n_nodes > 1:
        l_size = _MPI_comm.allreduce(l_size, op=_MPI.SUM)

    return l_size

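# Usage sketch (hypothetical, with 4 MPI ranks): counting samples distributed
# across ranks.
#
#   samples = _np.zeros((100, 16))    # per-rank batch of 100 samples
#   total_size(samples, axis=0)       # -> 400, global number of samples
#   total_size(samples)               # -> 6400, global number of elements
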
def sum(a, axis=None, out=None, keepdims: bool = False):
    """
    Compute the sum along the specified axis and over MPI processes.

    Args:
        a: The input array.
        axis: Axis or axes along which the sum is computed. The default (None)
            is to compute the sum of the flattened array.
        out: An optional pre-allocated array to fill with the result.
        keepdims: If True the output array will have the same number of
            dimensions as the input, with the reduced axes having length 1.
            (default=False)

    Returns:
        The array with reduced dimensions defined by axis. If out is not None,
        returns out.
    """
    # asarray is necessary for the axis=None case to work, as the MPI call
    # requires a NumPy array
    out = _np.asarray(_np.sum(a, axis=axis, out=out, keepdims=keepdims))

    if _n_nodes > 1:
        _MPI_comm.Allreduce(_MPI.IN_PLACE, out.reshape(-1), op=_MPI.SUM)

    return out

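# Usage sketch (hypothetical, with 2 MPI ranks): the local numpy sum is computed
# first, then reduced elementwise across ranks with Allreduce.
#
#   a = _np.arange(6.0).reshape(2, 3)   # identical on both ranks in this example
#   sum(a, axis=0)                      # local [3., 5., 7.] -> global [6., 10., 14.]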