def test_sizes(ctx_with_gs_limits, gl_size, gs_is_multiple):
    """
    Check the values reported by the virtual size functions inside a kernel.

    A kernel is compiled for the global/local sizes taken from a
    ``ReferenceIds`` object. The work item whose ``virtual_global_flat_id()``
    is 0 fills a 10-element buffer with, in order:

    * ``virtual_local_size(i)`` for ``i = 0, 1, 2``;
    * ``virtual_num_groups(i)`` for ``i = 0, 1, 2``;
    * ``virtual_global_size(i)`` for ``i = 0, 1, 2``;
    * ``virtual_global_flat_size()``.

    The buffer is then compared with the same quantities computed on the host.

    :param ctx_with_gs_limits: context used to compile the kernel and
        allocate the output buffer.
    :param gl_size: pair ``(global size, local size)``.
    :param gs_is_multiple: passed through to ``ReferenceIds`` unchanged.
    """
    ctx = ctx_with_gs_limits
    requested_global, requested_local = gl_size
    ref = ReferenceIds(requested_global, requested_local, gs_is_multiple)

    # Kernel text is kept verbatim; every work item except flat id 0 exits early.
    kernel_src = """ KERNEL void get_sizes(GLOBAL_MEM int *sizes) { if (virtual_global_flat_id() > 0) return; for (int i = 0; i < 3; i++) { sizes[i] = virtual_local_size(i); sizes[i + 3] = virtual_num_groups(i); sizes[i + 6] = virtual_global_size(i); } sizes[9] = virtual_global_flat_size(); } """
    get_sizes = ctx.compile_static(
        kernel_src, 'get_sizes', ref.global_size, local_size=ref.local_size)

    sizes = ctx.allocate(10, numpy.int32)
    get_sizes(sizes)

    def pad_to_3d(dims):
        # The kernel queries three dimensions, so missing ones count as 1.
        return list(dims) + [1] * (3 - len(dims))

    global_3d = pad_to_3d(ref.global_size)
    local_3d = pad_to_3d(ref.local_size)
    groups_3d = [
        min_blocks(global_len, local_len)
        for global_len, local_len in zip(global_3d, local_3d)
    ]

    # Same layout as the kernel output: local, groups, global, flat size.
    expected = numpy.array(
        local_3d + groups_3d + global_3d + [product(global_3d)]
    ).astype(numpy.int32)
    assert diff_is_negligible(sizes.get(), expected)
def predict_local_ids(self, dim):
    """
    Build the reference array of local ids along dimension ``dim``.

    The result has shape ``self.np_global_size`` and dtype ``int32``.
    If ``dim`` is beyond the last dimension of ``self.global_size``,
    the array is all zeros. Otherwise the values along the matching axis
    cycle through ``0 .. local_len - 1`` and are constant along every
    other axis.

    :param dim: index of the dimension, counted in the order of
        ``self.global_size``.
    :returns: ``numpy.int32`` array of predicted local ids.
    """
    ndim = len(self.global_size)
    if dim > ndim - 1:
        # Dimension is not present in the grid: every id along it is 0.
        return numpy.zeros(self.np_global_size, dtype=numpy.int32)

    # ``np_*`` sizes are indexed from the opposite end compared to ``dim``
    # (presumably they hold the same sizes in reversed order -- TODO confirm).
    axis = ndim - dim - 1
    axis_len = self.np_global_size[axis]
    block_len = self.np_local_size[axis]

    # One period 0..block_len-1, repeated for every block and cut to the axis length.
    block_count = min_blocks(axis_len, block_len)
    ids_1d = numpy.tile(numpy.arange(block_len), block_count)[:axis_len]

    # Put the 1D pattern on ``axis`` and replicate it over all other axes.
    line_shape = []
    replication = []
    for i, extent in enumerate(self.np_global_size):
        if i == axis:
            line_shape.append(extent)
            replication.append(1)
        else:
            line_shape.append(1)
            replication.append(extent)

    ids = numpy.tile(ids_1d.reshape(*line_shape), replication)
    return ids.astype(numpy.int32)