예제 #1
0
 def get_partition(self: 'PList[T]') -> 'PList[SList[T]]':
     """Wrap each processor's whole local block as a single element.

     The returned list has exactly one element per processor: the local
     content of ``self`` wrapped in an ``SList``. Consequently every local
     size is 1, the global size is ``_NPROCS``, and the element held by
     processor ``_PID`` sits at global index ``_PID``.
     """
     partition = PList()
     partition.__content = SList([self.__content])
     # One element on every processor.
     partition.__distribution = [1 for _proc in par.procs()]
     partition.__start_index = _PID
     partition.__local_size = 1
     partition.__global_size = _NPROCS
     return partition
예제 #2
0
 def permute(self: 'PList[T]', bij: Callable[[int], int]) -> 'PList[T]':
     """Redistribute the elements of the list according to ``bij``.

     Each element at global index ``i`` is routed (via ``Distribution.to_pid``
     applied to ``bij(i)``) to the processor owning its destination index.
     The routed items are exchanged with a single ``alltoall``. The received
     pairs are then sorted and their second component is kept as the new
     local content. The result has the same shape as ``self``.

     NOTE(review): assumes ``bij`` is a bijection on the global indices and
     that the pairs produced by ``to_pid`` / ``_group_by`` are
     ``(new_index, value)``, so sorting orders values by destination index.
     Confirm against those helpers.

     Args:
         bij: maps an element's current global index to its new one.

     Returns:
         A PList with the same distribution as ``self``.
     """
     result = self.__get_shape()
     distribution = Distribution(self.__distribution)

     def tag_with_target(index, value):
         # Attach the destination processor computed from the new index.
         return distribution.to_pid(bij(index), value)

     grouped = self.mapi(tag_with_target).get_partition().map(_group_by)
     outgoing_by_pid = grouped.__content[0]
     destinations = outgoing_by_pid.keys()
     # One message per processor; processors receiving nothing get [].
     outgoing = [
         outgoing_by_pid[target] if target in destinations else []
         for target in par.procs()
     ]
     received = SList(parimpl.COMM.alltoall(outgoing)).flatten()
     received.sort()
     result.__content = received.map(lambda pair: pair[1])
     return result
예제 #3
0
 def from_seq(sequence: Sequence[T]) -> 'PList[T]':
     """Build a parallel list whose elements all live on processor 0.

     Only processor 0 (``_PID == 0``) stores ``sequence``. Every other
     processor holds no elements (its local size is 0). The distribution
     ``[len(sequence), 0, ..., 0]`` is computed on processor 0 and broadcast
     over ``_COMM``. Every processor therefore agrees on sizes and start
     indices, and this function must be called by all of them.

     Args:
         sequence: the data to wrap; only read on processor 0.

     Returns:
         A PList with distribution ``[len(sequence), 0, ..., 0]``.
     """
     p_list = PList()
     root_distribution = None
     if _PID == 0:
         p_list.__content = SList(sequence)
         root_distribution = [
             len(sequence) if i == 0 else 0 for i in par.procs()
         ]
     # bcast ignores the object passed by non-root processors, so they send
     # None instead of relying on a pre-initialised ``__distribution``.
     from_root = _COMM.bcast(root_distribution, 0)
     p_list.__distribution = Distribution(from_root)
     p_list.__local_size = p_list.__distribution[_PID]
     # Total element count over all processors (equals len(sequence)).
     p_list.__global_size = sum(from_root)
     p_list.__start_index = SList(p_list.__distribution).scanl(add, 0)[_PID]
     return p_list
예제 #4
0
    def scatter(self: 'PList[T]', pid: int) -> 'PList[T]':
        """Spread the block held by processor ``pid`` over all processors.

        Each processor's local part is first wrapped as one element (via
        ``get_partition``). Every part except ``pid``'s is replaced by an
        empty list, and the result is flattened so that all of ``pid``'s
        elements sit on ``pid``. That list is then redistributed with
        ``Distribution.balanced`` over its length.

        Args:
            pid: the source processor; must be one of ``par.procs()``.

        Returns:
            A PList containing exactly the elements that were local to
            ``pid``, distributed by ``Distribution.balanced``.
        """
        assert pid in par.procs()

        only_source_sizes = Distribution([
            size if owner == pid else 0
            for owner, size in enumerate(self.distribution)
        ])

        def keep_source_block(owner, block):
            # Processors other than the source contribute nothing.
            return block if owner == pid else []

        partitions = self.get_partition().mapi(keep_source_block)
        on_source = partitions.flatten(only_source_sizes)

        return on_source.distribute(Distribution.balanced(on_source.length()))
예제 #5
0
 def balanced(size: int) -> 'Distribution':
     """Return a Distribution assigning ``size`` elements across processors.

     Processor ``pid`` receives ``parallel.local_size(pid, size)`` elements
     (presumably a near-even split; confirm in ``parallel.local_size``).
     """
     per_processor = list(
         map(lambda proc: parallel.local_size(proc, size), procs()))
     return Distribution(per_processor)
예제 #6
0
 def gather(self: 'PList[T]', pid: int) -> 'PList[T]':
     """Move every element of the list onto processor ``pid``.

     Args:
         pid: destination processor; must be one of ``par.procs()``.

     Returns:
         The list redistributed so that processor ``pid`` holds all
         ``self.length()`` elements and every other processor holds none.
     """
     assert pid in par.procs()

     def size_on(proc):
         # Only the destination processor is assigned elements.
         return self.length() if proc == pid else 0

     return self.distribute(Distribution(list(map(size_on, par.procs()))))