def get_partition(self: 'PList[T]') -> 'PList[SList[T]]':
    """Return a parallel list with exactly one item per processor.

    On each processor, that single item is this processor's whole local
    content, wrapped in an ``SList``. The item's global index is therefore
    the processor id (``_PID``).
    """
    partition = PList()
    # The local block becomes the one and only local element.
    partition.__content = SList([self.__content])
    partition.__distribution = [1 for _proc in par.procs()]
    partition.__local_size = 1
    partition.__global_size = _NPROCS
    partition.__start_index = _PID
    return partition
def permute(self: 'PList[T]', bij: Callable[[int], int]) -> 'PList[T]':
    """Return a new parallel list with elements moved according to ``bij``.

    The element at global index ``i`` is routed using ``bij(i)``. The routing
    goes through ``Distribution.to_pid`` and an all-to-all exchange, and the
    received entries are then ordered locally. ``bij`` is expected to be a
    bijection on global indices.

    :param bij: maps a source global index to its destination global index.
    :return: a list with the same shape as ``self``, holding permuted values.
    """
    result = self.__get_shape()
    distribution = Distribution(self.__distribution)

    def route(index, value):
        # Tag each value with its destination as computed by the distribution.
        return distribution.to_pid(bij(index), value)

    grouped = self.mapi(route).get_partition().map(_group_by)
    by_destination = grouped.__content[0]
    present = by_destination.keys()

    # One message per processor; processors with nothing to receive get [].
    outgoing = []
    for dest in par.procs():
        outgoing.append(by_destination[dest] if dest in present else [])

    received = SList(parimpl.COMM.alltoall(outgoing)).flatten()
    # NOTE(review): entries presumably look like (local_index, value), so
    # sorting restores destination order. Confirm against Distribution.to_pid.
    received.sort()
    result.__content = received.map(lambda pair: pair[1])
    return result
def from_seq(sequence: Sequence[T]) -> 'PList[T]':
    """Build a parallel list from ``sequence`` as given on processor 0.

    Only processor 0 stores the sequence. Every other processor is assigned
    a local size of 0. The distribution computed on processor 0 is broadcast,
    so all processes agree on sizes and start indices.

    :param sequence: the data; only processor 0's argument is used.
    :return: the new parallel list.
    """
    result = PList()
    if _PID == 0:
        result.__content = SList(sequence)
        result.__distribution = [
            len(sequence) if proc == 0 else 0 for proc in par.procs()
        ]
    # Collective operation: every process must reach it, not only the root.
    sizes = _COMM.bcast(result.__distribution, 0)
    result.__distribution = Distribution(sizes)
    result.__local_size = result.__distribution[_PID]
    # All elements live on processor 0, so its size is the global size.
    result.__global_size = result.__distribution[0]
    offsets = SList(result.__distribution).scanl(add, 0)
    result.__start_index = offsets[_PID]
    return result
def scatter(self: 'PList[T]', pid: int) -> 'PList[T]':
    """Spread the local content of processor ``pid`` across all processors.

    Only the elements held by ``pid`` are kept. They are then redistributed
    according to ``Distribution.balanced``.

    :param pid: processor whose local elements are scattered.
    :return: a parallel list containing ``pid``'s elements, rebalanced.
    """
    assert pid in par.procs()

    def keep_only_pid(index, block):
        # In the partition, an item's index equals its owning processor id.
        return block if index == pid else []

    sizes = [
        size if index == pid else 0
        for index, size in enumerate(self.distribution)
    ]
    only_pid = Distribution(sizes)
    from_pid = self.get_partition().mapi(keep_only_pid).flatten(only_pid)
    balanced = Distribution.balanced(from_pid.length())
    return from_pid.distribute(balanced)
def balanced(size: int) -> 'Distribution':
    """Return a distribution of ``size`` elements over all processors.

    Each processor's share is ``parallel.local_size(proc, size)``.

    :param size: total number of elements to distribute.
    :return: the resulting ``Distribution``.
    """
    shares = []
    for proc in procs():
        shares.append(parallel.local_size(proc, size))
    return Distribution(shares)
def gather(self: 'PList[T]', pid: int) -> 'PList[T]':
    """Move every element of the parallel list onto processor ``pid``.

    :param pid: destination processor.
    :return: a parallel list whose distribution puts all ``self.length()``
        elements on ``pid`` and none on any other processor.
    """
    assert pid in par.procs()
    target_sizes = []
    for proc in par.procs():
        if proc == pid:
            target_sizes.append(self.length())
        else:
            target_sizes.append(0)
    return self.distribute(Distribution(target_sizes))