def test_chunked_axis(comm):
    """Check index mapping for a 1-D distribution built on a ChunkedAxis.

    A ``Distribution`` of global shape ``(400,)`` is laid out over a
    ``(2,)`` grid using ``ChunkedAxis([0, 200, 400])``. The test expects:

    * ``local_shape()`` to be ``(200,)``;
    * ``local_to_global`` to shift local indices by ``200 * rank``;
    * ``global_to_rank_coords`` to send global indices below 200 to
      coordinate 0 and indices from 200 upward to coordinate 1.

    ``comm`` is the communicator handed to ``Grid``; only ``Get_rank()`` is
    called on it here.
    NOTE(review): presumably an MPI communicator fixture with two ranks,
    confirm against the test configuration.
    """
    chunk_bounds = [0, 200, 400]
    grid = Grid(comm, (2,))
    dist = Distribution(
        grid,
        [ChunkedAxis(chunk_bounds)],
        shape=(400,),
        strides=(1,),
    )

    # Index arrays carry a leading axis of length 1 (one row per dimension
    # of the distribution is assumed, TODO confirm).
    local_idx = np.arange(20)[np.newaxis, :].copy()

    assert dist.local_shape() == (200,)

    # local_to_global
    mapped = dist.local_to_global(local_idx)
    offset = 200 * comm.Get_rank()
    expected_global = np.arange(offset, offset + 20)
    assert np.all(mapped[0, :] == expected_global)

    # global_to_rank
    # Ten consecutive global indices straddling the chunk boundary at 200.
    probe = np.arange(195, 205)[np.newaxis, :].copy()
    expected_coords = [0] * 5 + [1] * 5
    assert np.all(dist.global_to_rank_coords(probe) == expected_coords)