def test_distribution_pid_of_index_with_empty_3():  # pylint: disable=missing-docstring
    # Processors 0 and 2 hold no elements and processor 1 holds 10, so global
    # index 10 must be mapped to processor 3 and returned together with the
    # (index, value) pair.
    payload = random.randint(0, 100)
    expected = (3, (10, payload))
    actual = Distribution([0, 10, 0, 20]).to_pid(10, payload)
    assert actual == expected
def test_distribution_pid_of_index_3():  # pylint: disable=missing-docstring
    # Prefix sums of the sizes are 4, 16, 43: global index 43 is therefore
    # the first element owned by processor 3.
    payload = random.randint(0, 100)
    expected = (3, (43, payload))
    actual = Distribution([4, 12, 27, 18]).to_pid(43, payload)
    assert actual == expected
def permute(self: 'PList[T]', bij: Callable[[int], int]) -> 'PList[T]':
    """Return a list in which the element at global index ``i`` of ``self``
    is moved to global index ``bij(i)``.

    ``bij`` is expected to be a bijection on the global indices (TODO:
    confirm callers guarantee this). Every element is tagged with the
    processor owning its new index (according to this list's distribution),
    the tagged elements are exchanged with one all-to-all, and each
    processor sorts what it receives by new index before dropping the tags.

    The result starts from ``self.__get_shape()`` (presumably an empty list
    with the same layout as ``self`` -- verify).
    """
    result = self.__get_shape()
    layout = Distribution(self.__distribution)
    # Each element becomes (pid, (new_index, value)); _group_by presumably
    # groups these by destination pid on every processor.
    grouped = self.mapi(
        lambda idx, val: layout.to_pid(bij(idx), val)
    ).get_partition().map(_group_by)
    by_dest = grouped.__content[0]
    present = by_dest.keys()
    outgoing = []
    for dest in par.procs():
        if dest in present:
            outgoing.append(by_dest[dest])
        else:
            outgoing.append([])
    received = SList(parimpl.COMM.alltoall(outgoing)).flatten()
    # Sorting the (new_index, value) pairs restores global index order.
    received.sort()
    result.__content = received.map(lambda tagged: tagged[1])
    return result
def __init__(self: 'PList[T]'):  # pylint: disable=super-init-not-called
    """Create an empty distributed list.

    The local content is empty, all sizes and the start index are zero, and
    the distribution assigns zero elements to each of the ``_NPROCS``
    processors.
    """
    self.__content: 'SList[T]' = SList([])
    self.__global_size: int = 0
    self.__local_size: int = 0
    self.__start_index: int = 0
    self.__distribution = Distribution([0] * _NPROCS)
def scatter(self: 'PList[T]', pid: int) -> 'PList[T]':
    """Spread the local part of processor ``pid`` over all processors.

    Only the elements currently held by ``pid`` are kept. They are first
    flattened into a list that lives entirely on ``pid``, then redistributed
    with ``Distribution.balanced``.

    :param pid: processor whose local part is scattered; must be in
        ``par.procs()``.
    """
    assert pid in par.procs()

    def keep_only_pid(index, part):
        # Each processor contributes one partition; only pid's one survives.
        return part if index == pid else []

    only_pid_sizes = []
    for proc, size in enumerate(self.distribution):
        only_pid_sizes.append(size if proc == pid else 0)
    on_pid = self.get_partition().mapi(keep_only_pid).flatten(
        Distribution(only_pid_sizes))
    return on_pid.distribute(Distribution.balanced(on_pid.length()))
def test_distribution_is_valid():  # pylint: disable=missing-docstring
    # The 4-entry distribution is checked under a 4-processor configuration
    # (presumably is_valid also compares the number of entries with
    # parallel.NPROCS). The global is restored in `finally` so that an
    # exception raised by is_valid cannot leak the modified processor count
    # into other tests.
    saved_nprocs = parallel.NPROCS
    parallel.NPROCS = 4
    try:
        distr = [4, 12, 27, 18]
        res = Distribution(distr).is_valid(sum(distr))
    finally:
        parallel.NPROCS = saved_nprocs
    assert res
def scatter_range(self: 'PList[T]', rng) -> 'PList[T]':
    """Keep the elements whose global index is in ``rng`` and balance them.

    :param rng: any container supporting ``index in rng`` for integer
        indices (e.g. a ``range``).
    :return: a new list holding the selected elements, redistributed with
        ``Distribution.balanced``.
    """
    # Tag every element with its index and filter on the index itself.
    # Using None as a "not selected" marker would also drop elements whose
    # value is None even though their index is in rng.
    selected = (
        self.mapi(lambda index, value: (index, value))
        .filter(lambda tagged: tagged[0] in rng)
        .map(lambda tagged: tagged[1])
    )
    return selected.distribute(Distribution.balanced(selected.length()))
def from_seq(sequence: Sequence[T]) -> 'PList[T]':
    """Build a distributed list from ``sequence`` as seen on processor 0.

    Only processor 0 stores the elements; every other processor keeps an
    empty local part. Processor 0's distribution (all elements on
    processor 0) is broadcast so that every processor agrees on the local
    size, the global size and the start index.
    """
    p_list = PList()
    if _PID == 0:
        p_list.__content = SList(sequence)
        p_list.__distribution = [
            len(sequence) if proc == 0 else 0 for proc in par.procs()
        ]
    # Non-root processors pass their default distribution; only the root's
    # value is used by the broadcast.
    shared = Distribution(_COMM.bcast(p_list.__distribution, 0))
    p_list.__distribution = shared
    p_list.__local_size = shared[_PID]
    # Processor 0 holds everything, so its size is the global size.
    p_list.__global_size = shared[0]
    p_list.__start_index = SList(shared).scanl(add, 0)[_PID]
    return p_list
def init(value_at: Callable[[int], T], size: int = _NPROCS) -> 'PList[T]':
    """Create a distributed list of ``size`` elements, element ``i`` being
    ``value_at(i)``.

    Local sizes come from ``parimpl.local_size``. Each processor only calls
    ``value_at`` on the global indices it owns.

    :param value_at: function from global index to element value.
    :param size: number of elements; must be non-negative (defaults to
        ``_NPROCS``).
    """
    assert size >= 0
    p_list = PList()
    p_list.__global_size = size
    p_list.__local_size = parimpl.local_size(_PID, size)
    # Keep __distribution a Distribution, as __init__, from_seq and
    # distribute do. Do not overwrite it afterwards with a plain list.
    p_list.__distribution = Distribution(
        [parimpl.local_size(pid, size) for pid in range(0, _NPROCS)])
    p_list.__start_index = SList(p_list.__distribution).scanl(add, 0)[_PID]
    first = p_list.__start_index
    p_list.__content = SList(
        [value_at(i) for i in range(first, first + p_list.__local_size)])
    return p_list
def distribute(self: 'PList[T]', target_distr: Distribution) -> 'PList[T]':
    """Return a copy of this list laid out according to ``target_distr``.

    Each processor intersects the global index interval it currently owns
    with every processor's target interval. It sends the matching slice of
    its local content to that processor in a single all-to-all exchange,
    then concatenates the slices it receives.

    :param target_distr: per-processor sizes of the result; must be valid
        for this list's global size.
    """
    assert Distribution.is_valid(target_distr, self.__global_size)
    own = interval.bounds(self.__distribution)[_PID]
    overlaps = interval.bounds(target_distr).map(
        lambda dest: interval.intersection(dest, own))
    # Shift the global intervals to local offsets before slicing.
    outgoing = []
    for part in overlaps:
        local_part = interval.shift(part, -self.__start_index)
        outgoing.append(interval.to_slice(self.__content, local_part))
    incoming = _COMM.alltoall(outgoing)

    result = PList()
    result.__content = SList(incoming).flatten()
    result.__global_size = self.__global_size
    result.__local_size = target_distr[_PID]
    result.__start_index = SList(target_distr).scanl(add, 0)[_PID]
    result.__distribution = target_distr
    return result
def gather(self: 'PList[T]', pid: int) -> 'PList[T]':
    """Move every element of the list onto processor ``pid``.

    :param pid: destination processor; must be in ``par.procs()``.
    """
    assert pid in par.procs()
    total = self.length()
    all_on_pid = Distribution(
        [total if proc == pid else 0 for proc in par.procs()])
    return self.distribute(all_on_pid)
def balance(self: 'PList[T]') -> 'PList[T]':
    """Redistribute the elements using ``Distribution.balanced`` for the
    current length."""
    target = Distribution.balanced(self.length())
    return self.distribute(target)
def test_distribution_balanced_not_valid():  # pylint: disable=missing-docstring
    # A balanced distribution built for `size` elements must be rejected
    # when checked against one element fewer.
    size = random.randint(10, 100)
    assert not Distribution.balanced(size).is_valid(size - 1)
def test_distribution_is_not_valid():  # pylint: disable=missing-docstring
    # Use the same 4-processor configuration as test_distribution_is_valid so
    # that the size mismatch is what makes the distribution invalid, not a
    # possible mismatch between the number of entries and parallel.NPROCS.
    # The global is restored in `finally` so it cannot leak into other tests.
    saved_nprocs = parallel.NPROCS
    parallel.NPROCS = 4
    try:
        distr = [4, 12, 27, 18]
        res = Distribution(distr).is_valid(sum(distr) - 1)
    finally:
        parallel.NPROCS = saved_nprocs
    assert not res