def init(value_at: Callable[[int], T], size: int = _NPROCS) -> 'PList[T]':
    """Create a distributed list whose element at global index ``i`` is ``value_at(i)``.

    The ``size`` elements are split over the ``_NPROCS`` processors, processor
    ``pid`` owning ``parimpl.local_size(pid, size)`` of them.  Each processor
    only evaluates ``value_at`` on the global indices it owns.

    :param value_at: function mapping a global index to the value stored there.
    :param size: global number of elements (defaults to ``_NPROCS``); must be >= 0.
    :return: a new ``PList`` holding this processor's part of the elements.
    """
    assert size >= 0
    p_list = PList()
    p_list.__global_size = size
    p_list.__local_size = parimpl.local_size(_PID, size)
    # Number of elements owned by each processor, in processor-id order.
    distribution_list = [
        parimpl.local_size(i, size) for i in range(0, _NPROCS)
    ]
    # Stored (and kept) as a ``Distribution`` object: it is not re-assigned
    # to a plain list afterwards, so the field keeps a single type.
    p_list.__distribution = Distribution(distribution_list)
    # Global index of this processor's first element: the prefix sum of the
    # sizes owned by processors before _PID (assuming ``scanl`` starts its
    # result with the initial value 0).
    p_list.__start_index = SList(p_list.__distribution).scanl(
        lambda x, y: x + y, 0)[_PID]
    p_list.__content = SList([
        value_at(i) for i in range(p_list.__start_index,
                                   p_list.__start_index + p_list.__local_size)
    ])
    return p_list
def _mpi_scatter(input_list):
    """Scatter ``input_list`` from rank 0 so that each process gets one slice.

    The list is cut into ``Get_size()`` contiguous slices whose lengths are
    ``local_size(pid, len(input_list))``; process ``pid`` receives slice
    ``pid``.  Every process calls this function and evaluates
    ``len(input_list)``, so every process must pass an object supporting
    ``len``; the slices sent are those built on rank 0 (the scatter root).

    :param input_list: the sequence to distribute (its contents matter on rank 0).
    :return: this process's slice of ``input_list``.
    """
    world = MPI.COMM_WORLD  # pylint: disable=c-extension-no-member
    nb_procs = world.Get_size()
    total = len(input_list)
    sizes = [local_size(pid, total) for pid in range(0, nb_procs)]
    # Running totals of ``sizes`` starting at 0: consecutive entries delimit
    # the slice destined to each process.
    offsets = scan(sizes, add, 0)
    starts = offsets[0:nb_procs]
    stops = offsets[1:]
    # Built lazily; ``scatter`` consumes it when sending from the root.
    chunks = (input_list[start:stop] for start, stop in zip(starts, stops))
    return world.scatter(chunks, 0)
def balanced(size: int) -> 'Distribution':
    """Return the distribution of ``size`` elements given by ``parallel.local_size``.

    Entry ``k`` of the result is ``parallel.local_size(pid, size)`` for the
    ``k``-th processor id ``pid`` yielded by ``procs()``.

    :param size: global number of elements to distribute.
    :return: a ``Distribution`` with one local size per processor.
    """
    return Distribution(
        [parallel.local_size(proc_id, size) for proc_id in procs()])