def test_barrier_counter_barriers():
    """Count local barriers for a kernel whose second instruction reads
    neighbouring ``k`` entries of a temporary written by the first.

    ``c`` is a kernel temporary; ``e`` reads ``c[i,j,k+1]`` and
    ``c[i,j,k-1]`` and depends on the instruction that writes ``c``.
    ``k`` is split by 128 with its inner part tagged ``l.0``. The test
    asserts the ``barrier_local`` count equals ``50*10*2``. The ``i`` and
    ``j`` extents are 50 and 10; the factor 2 is presumably barriers per
    ``(i, j)`` iteration (TODO confirm).
    """
    knl = lp.make_kernel(
        "[n,m,ell] -> {[i,k,j]: 0<=i<50 and 1<=k<98 and 0<=j<10}",
        [
            # Each instruction must sit on its own line: loopy parses the
            # instruction string line by line, so two instructions run
            # together on one line would not parse as two instructions.
            """
            c[i,j,k] = 2*a[i,j,k] {id=first}
            e[i,j,k] = c[i,j,k+1]+c[i,j,k-1] {dep=first}
            """
        ],
        [
            lp.TemporaryVariable("c", lp.auto, shape=(50, 10, 99)),
            "..."
        ],
        name="weird2",
    )
    knl = lp.add_and_infer_dtypes(knl, dict(a=np.int32))
    knl = lp.split_iname(knl, "k", 128, inner_tag="l.0")

    sync_map = lp.get_synchronization_map(knl)
    print(sync_map)

    # The domain declares n, m, ell as parameters even though its bounds are
    # constant, so values are still supplied for evaluation.
    n = 512
    m = 256
    ell = 128
    params = {'n': n, 'm': m, 'ell': ell}
    barrier_count = sync_map["barrier_local"].eval_with_dict(params)
    assert barrier_count == 50*10*2
def test_all_counters_parallel_matmul():
    """Check sync, op and memory-access counts for a matmul whose ``i``,
    ``j`` and ``k`` inames are split by 16 and whose ``a`` and ``b`` inputs
    are prefetched over the inner inames.
    """
    # NOTE(review): a second ``def test_all_counters_parallel_matmul`` appears
    # later in this module. That later ``def`` rebinds the name at import
    # time, so this version is never collected or run by pytest.
    # It also uses a different counting API from the later version:
    # ``stride=`` on ``lp.MemAccess``, and no count granularity on
    # ``lp.Op`` / ``lp.MemAccess``. Decide whether to delete it or rename
    # and port it -- TODO confirm.
    knl = lp.make_kernel("{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
                         ["c[i, j] = sum(k, a[i, k]*b[k, j])"],
                         name="matmul", assumptions="n,m,l >= 1")
    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
    # i and j: outer parts tagged g.*, inner parts tagged l.*; k split only.
    knl = lp.split_iname(knl, "i", 16, outer_tag="g.0", inner_tag="l.1")
    knl = lp.split_iname(knl, "j", 16, outer_tag="g.1", inner_tag="l.0")
    knl = lp.split_iname(knl, "k", 16)
    knl = lp.add_prefetch(knl, "a", ["k_inner", "i_inner"])
    knl = lp.add_prefetch(knl, "b", ["j_inner", "k_inner"])
    n = 512
    m = 256
    l = 128
    params = {'n': n, 'm': m, 'l': l}
    sync_map = lp.get_synchronization_map(knl)
    assert len(sync_map) == 2
    assert sync_map["kernel_launch"].eval_with_dict(params) == 1
    assert sync_map["barrier_local"].eval_with_dict(params) == 2 * m / 16
    op_map = lp.get_op_map(knl)
    f32mul = op_map[lp.Op(np.float32, 'mul')].eval_with_dict(params)
    f32add = op_map[lp.Op(np.float32, 'add')].eval_with_dict(params)
    # NOTE(review): i32ops is accumulated but never asserted on.
    i32ops = op_map[lp.Op(np.int32, 'add')].eval_with_dict(params)
    i32ops += op_map[lp.Op(np.dtype(np.int32), 'mul')].eval_with_dict(params)
    assert f32mul + f32add == n * m * l * 2
    # op_map is rebound here: from this point it holds the memory-access map.
    op_map = lp.get_mem_access_map(knl)
    f32coal = op_map[lp.MemAccess('global', np.float32, stride=1,
                                  direction='load', variable='b')
                     ].eval_with_dict(params)
    f32coal += op_map[lp.MemAccess('global', np.float32, stride=1,
                                   direction='load', variable='a')
                      ].eval_with_dict(params)
    assert f32coal == n * m + m * l
    f32coal = op_map[lp.MemAccess('global', np.float32, stride=1,
                                  direction='store', variable='c')
                     ].eval_with_dict(params)
    assert f32coal == n * l
    local_mem_map = lp.get_mem_access_map(knl).filter_by(mtype=['local'])
    local_mem_l = local_mem_map[lp.MemAccess(
        'local', np.dtype(np.float32), direction='load')
    ].eval_with_dict(params)
    assert local_mem_l == n * m * l * 2
def test_barrier_counter_nobarriers():
    """Check that a kernel with no tagged (parallel) inames reports only a
    single kernel launch in its synchronization map.

    No ``split_iname`` or tagging transformations are applied, so the only
    entry expected in the map is ``kernel_launch`` with a count of 1.
    """
    knl = lp.make_kernel(
        "[n,m,ell] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<ell}",
        [
            # One instruction per line: run together on a single line, the
            # second assignment would end up inside the first one's
            # right-hand side and fail to parse.
            """
            c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
            e[i, k] = g[i,k]*h[i,k+1]
            """
        ],
        name="basic", assumptions="n,m,ell >= 1")
    knl = lp.add_and_infer_dtypes(
        knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))

    sync_map = lp.get_synchronization_map(knl)

    n = 512
    m = 256
    ell = 128
    params = {'n': n, 'm': m, 'ell': ell}
    assert len(sync_map) == 1
    assert sync_map["kernel_launch"].eval_with_dict(params) == 1
def test_all_counters_parallel_matmul():
    """Check synchronization, op and memory-access counts for a tiled,
    prefetching matrix multiply ``c[i, j] = sum(k, a[i, k]*b[k, j])``.

    ``i`` and ``j`` are split by ``bsize`` (outer parts tagged ``g.*``,
    inner parts tagged ``l.*``), ``k`` is split by ``bsize``, and ``a`` and
    ``b`` are prefetched into ``a_fetch`` / ``b_fetch``. Op and memory
    counts are requested with ``subgroup_size=SGS`` and
    ``count_redundant_work=True``. Sub-group-granularity counts are checked
    against ``(count-per-sub-group) * n_subgroups``.
    """
    bsize = 16
    knl = lp.make_kernel(
        "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<ell}",
        [
            "c[i, j] = sum(k, a[i, k]*b[k, j])"
        ],
        name="matmul", assumptions="n,m,ell >= 1")
    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
    knl = lp.split_iname(knl, "i", bsize, outer_tag="g.0", inner_tag="l.1")
    knl = lp.split_iname(knl, "j", bsize, outer_tag="g.1", inner_tag="l.0")
    knl = lp.split_iname(knl, "k", bsize)
    knl = lp.add_prefetch(knl, "a", ["k_inner", "i_inner"],
                          default_tag="l.auto")
    knl = lp.add_prefetch(knl, "b", ["j_inner", "k_inner"],
                          default_tag="l.auto")

    n = 512
    m = 256
    ell = 128
    params = {'n': n, 'm': m, 'ell': ell}

    # Launch geometry: one bsize x bsize group per output tile.
    group_size = bsize*bsize
    n_workgroups = div_ceil(n, bsize)*div_ceil(ell, bsize)
    subgroups_per_group = div_ceil(group_size, SGS)
    n_subgroups = n_workgroups*subgroups_per_group

    # Synchronization.
    sync_map = lp.get_synchronization_map(knl)
    assert len(sync_map) == 2
    assert sync_map["kernel_launch"].eval_with_dict(params) == 1
    assert sync_map["barrier_local"].eval_with_dict(params) == 2*m/bsize

    # Floating-point ops.
    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
    f32mul = op_map[
        lp.Op(np.float32, 'mul', CG.SUBGROUP)
    ].eval_with_dict(params)
    f32add = op_map[
        lp.Op(np.float32, 'add', CG.SUBGROUP)
    ].eval_with_dict(params)
    # (count-per-sub-group)*n_subgroups
    assert f32mul+f32add == m*2*n_subgroups

    # Global memory accesses.
    mem_access_map = lp.get_mem_access_map(knl, count_redundant_work=True,
                                           subgroup_size=SGS)

    f32s1lb = mem_access_map[lp.MemAccess('global', np.float32,
                             lid_strides={0: 1, 1: Variable('ell')},
                             gid_strides={1: bsize},
                             direction='load', variable='b',
                             count_granularity=CG.WORKITEM)
                             ].eval_with_dict(params)
    f32s1la = mem_access_map[lp.MemAccess('global', np.float32,
                             lid_strides={0: 1, 1: Variable('m')},
                             gid_strides={0: Variable('m')*bsize},
                             direction='load', variable='a',
                             count_granularity=CG.WORKITEM)
                             ].eval_with_dict(params)
    assert f32s1lb == n*m*ell/bsize
    assert f32s1la == n*m*ell/bsize

    f32coal = mem_access_map[lp.MemAccess('global', np.float32,
                             lid_strides={0: 1, 1: Variable('ell')},
                             gid_strides={0: Variable('ell')*bsize, 1: bsize},
                             direction='store', variable='c',
                             count_granularity=CG.WORKITEM)
                             ].eval_with_dict(params)
    assert f32coal == n*ell

    # Local memory accesses. Reuse the map computed above (same kernel and
    # arguments) instead of recomputing it just to filter.
    local_mem_map = mem_access_map.filter_by(mtype=['local'])

    local_mem_l = local_mem_map.filter_by(direction=['load']
                                          ).eval_and_sum(params)
    # (count-per-sub-group)*n_subgroups
    assert local_mem_l == m*2*n_subgroups

    # The a_fetch load stride along local axis 1 is the tile width, so it
    # is expressed in terms of bsize rather than a literal 16.
    local_mem_l_a = local_mem_map[lp.MemAccess('local', np.dtype(np.float32),
                                               direction='load',
                                               lid_strides={1: bsize},
                                               gid_strides={},
                                               variable='a_fetch',
                                               count_granularity=CG.SUBGROUP)
                                  ].eval_with_dict(params)
    local_mem_l_b = local_mem_map[lp.MemAccess('local', np.dtype(np.float32),
                                               direction='load',
                                               lid_strides={0: 1},
                                               gid_strides={},
                                               variable='b_fetch',
                                               count_granularity=CG.SUBGROUP)
                                  ].eval_with_dict(params)
    # (count-per-sub-group)*n_subgroups
    assert local_mem_l_a == local_mem_l_b == m*n_subgroups

    local_mem_s = local_mem_map.filter_by(direction=['store']
                                          ).eval_and_sum(params)
    # (count-per-sub-group)*n_subgroups
    assert local_mem_s == m*2/bsize*n_subgroups