def distribute_as(self, shape_or_dist):
    """
    Return a new DistArray holding this array's data under another layout.

    Parameters
    ----------
    shape_or_dist : shape tuple or Distribution object
        Layout for the returned DistArray.  It is converted with
        `asdistribution` using this array's context; a shape tuple becomes
        a Distribution object with default parameters.  The new
        distribution must describe the same number of items as this
        DistArray, but its global shape and targets may differ.

    Returns
    -------
    DistArray
        A new DistArray, with this array's dtype, distributed according to
        `shape_or_dist`.

    Raises
    ------
    NotImplementedError
        If either the current or the requested distribution has a
        dimension that is neither block ('b') nor non-distributed ('n').
    ValueError
        If the shapes differ and the total number of items differs too.

    Note
    ----
    Currently implemented for block and non-distributed maps only.
    """
    dist = asdistribution(self.context, shape_or_dist)

    # Both the source and the destination layouts must use supported kinds.
    supported = ('b', 'n')
    source_ok = all(d in supported for d in self.distribution.dist)
    if not (source_ok and all(d in supported for d in dist.dist)):
        raise NotImplementedError(
            "Only block and non-distributed dimensions currently supported.")

    # The two helpers below are handed to `self.context.apply`, so each one
    # performs its own import and relies on nothing from the enclosing scope.
    def _local_redistribute_same_shape(comm, plan, la_from, la_to):
        from distarray.localapi import redistribute
        redistribute(comm, plan, la_from, la_to)

    def _local_redistribute_general(comm, plan, la_from, la_to):
        from distarray.localapi import redistribute_general
        redistribute_general(comm, plan, la_from, la_to)

    n_source = self.global_size
    n_dest = reduce(operator.mul, dist.shape, 1)
    same_shape = self.distribution.shape == dist.shape

    # A different shape is acceptable only when the item count is preserved.
    if not same_shape and n_source != n_dest:
        template = ("Original size %d != new size %d,"
                    " and total size of new array must be unchanged.")
        raise ValueError(template % (n_source, n_dest))

    if same_shape:
        _local_redistribute = _local_redistribute_same_shape
    else:
        _local_redistribute = _local_redistribute_general

    plan = self.distribution.get_redist_plan(dist)
    ubercomm, all_targets = self.distribution.comm_union(dist)
    result = DistArray(dist, dtype=self.dtype)

    apply_args = (ubercomm, plan, self.key, result.key)
    self.context.apply(_local_redistribute, apply_args, targets=all_targets)
    return result
def _local_rand_call(self, local_func_name, shape_or_dist, kwargs=None):
    """
    Build a DistArray by calling a `distarray.localapi.random` function.

    Parameters
    ----------
    local_func_name : str
        Attribute name looked up on `distarray.localapi.random` where the
        call is applied.
    shape_or_dist : shape tuple or Distribution object
        Converted with `asdistribution` using `self.context`.
    kwargs : dict, optional
        Extra keyword arguments forwarded to the local function.  A falsy
        value (including the default ``None``) is replaced by an empty
        dict.

    Returns
    -------
    DistArray
        Built with `DistArray.from_localarrays` from the first element of
        the value returned by `self.context.apply`.
    """
    if not kwargs:
        kwargs = {}

    def _local_call(comm, local_func_name, ddpr, kwargs):
        # Passed to `self.context.apply`: imports are done locally.
        # `proxyize` is not defined in this block; it is presumably
        # available wherever the function executes -- TODO confirm.
        import distarray.localapi.random as local_random
        from distarray.localapi.maps import Distribution

        local_func = getattr(local_random, local_func_name)
        # An empty `ddpr` means there is no per-rank entry to select.
        dim_data = ddpr[comm.Get_rank()] if len(ddpr) else ()
        dist = Distribution(dim_data=dim_data, comm=comm)
        local_array = local_func(distribution=dist, **kwargs)
        return proxyize(local_array)

    distribution = asdistribution(self.context, shape_or_dist)
    call_args = (
        distribution.comm,
        local_func_name,
        distribution.get_dim_data_per_rank(),
        kwargs,
    )
    keys = self.context.apply(_local_call, call_args,
                              targets=distribution.targets)
    return DistArray.from_localarrays(keys[0], distribution=distribution)
def _create_local(self, local_call, shape_or_dist, dtype):
    """
    Create LocalArrays with the callable named in `local_call`.

    Parameters
    ----------
    local_call : str
        Expression naming the creation callable.  It is resolved with
        ``eval`` where the inner function executes.
    shape_or_dist : shape tuple or Distribution object
        Converted with `asdistribution` using `self`.
    dtype
        Forwarded to the creation callable and to
        `DistArray.from_localarrays`.

    Returns
    -------
    DistArray
        Built with `DistArray.from_localarrays` from the first element of
        the value returned by `self.apply`.
    """

    def create_local(local_call, ddpr, dtype, comm):
        from distarray.localapi.maps import Distribution

        # An empty `ddpr` means there is no per-rank entry to select.
        dim_data = () if len(ddpr) == 0 else ddpr[comm.Get_rank()]

        # NOTE(review): `eval` runs an arbitrary expression; `local_call`
        # must only ever come from trusted code.  Without explicit
        # namespaces it also sees this function's locals, so the names
        # bound above are intentionally left as they were.
        factory = eval(local_call)

        local_dist = Distribution(comm=comm, dim_data=dim_data)
        # `proxyize` is not defined in this block; it is presumably
        # available wherever the function executes -- TODO confirm.
        return proxyize(factory(distribution=local_dist, dtype=dtype))

    distribution = asdistribution(self, shape_or_dist)
    per_rank = distribution.get_dim_data_per_rank()
    call_args = [local_call, per_rank, dtype, distribution.comm]
    keys = self.apply(create_local, args=call_args,
                      targets=distribution.targets)
    return DistArray.from_localarrays(keys[0], distribution=distribution,
                                      dtype=dtype)