def _transform(local_distribution, glb_flat):
    """
    Convert a flat index into the global array into a flat index into the
    local array described by `local_distribution`.

    Parameters
    ----------
    local_distribution :
        Object providing ``global_shape``, ``local_from_global`` and
        ``local_flat_from_local`` (the only members used here).
    glb_flat :
        Flat index into the global array (presumably an int -- confirm
        against callers).

    Returns
    -------
    The flat local index, as returned by
    ``local_distribution.local_flat_from_local``.
    """
    # Unflatten the global index into per-dimension indices using the
    # global array's strides.
    glb_strides = strides_from_shape(local_distribution.global_shape)
    glb_ndim_inds = ndim_from_flat(glb_flat, glb_strides)
    # Translate the global N-d index to a local N-d index, then re-flatten it
    # in local coordinates.  (The previous version also computed the local
    # strides here but never used them; that dead computation was removed.)
    local_ind = local_distribution.local_from_global(glb_ndim_inds)
    return local_distribution.local_flat_from_local(local_ind)
def global_flat_indices(dim_data):
    """
    Compute the (start, stop) intervals of the flattened global array that
    belong to the sub-array described by `dim_data`.

    Parameters
    ----------
    dim_data : sequence of dimension dictionaries
        Each dictionary is read for its 'dist_type', 'size', 'start' and
        'stop' keys.  For dimensions whose 'dist_type' is 'n', 'start' and
        'stop' are overwritten in place with 0 and 'size' respectively.

    Returns
    -------
    Whatever ``condense`` returns when given the (start, stop) pairs.  Each
    pair is an interval of flat indices into the global array; together the
    intervals cover this dim_data's sub-array.
    """
    # TODO: FIXME: can be optimized when the last dimension is 'n'.

    # A non-distributed dimension spans its full extent.  Note that this
    # mutates the caller's dictionaries.
    for dim in dim_data:
        if dim['dist_type'] == 'n':
            dim['start'], dim['stop'] = 0, dim['size']

    strides = strides_from_shape(tuple(dim['size'] for dim in dim_data))

    def to_flat(index):
        # Dot product of an N-d index with the global strides.
        return sum(i * s for i, s in zip(index, strides))

    # Every combination of indices over the leading dimensions yields one
    # contiguous run along the last dimension, from its 'start' to its 'stop'.
    leading_ranges = [range(dim['start'], dim['stop']) for dim in dim_data[:-1]]
    last_dim = dim_data[-1]
    lower = map(to_flat, product(*(leading_ranges + [[last_dim['start']]])))
    upper = map(to_flat, product(*(leading_ranges + [[last_dim['stop']]])))
    return condense(zip(lower, upper))