def test_estimate_flop():
    """Check that ComputeDAG.flop_ct counts floating point operations correctly.

    Covers: plain matmul, matmul + elementwise relu, padding (comparisons
    must not be counted), a user-supplied "FLOP" attr override, and
    if_then_else (counted as one op per output element).
    """
    n = 512

    # Matmul alone: 2 * n^3 flops (one multiply + one add per MAC).
    A, B, C = matmul_auto_scheduler_test(n, n, n)
    compute_dag = auto_scheduler.ComputeDAG([A, B, C])
    assert abs(compute_dag.flop_ct - 2 * n**3) < 0.5

    # Adding relu contributes one op per element: + n * n.
    D = topi.nn.relu(C)
    compute_dag = auto_scheduler.ComputeDAG([A, B, D])
    assert abs(compute_dag.flop_ct - (2 * n**3 + n * n)) < 0.5

    # should not count the comparison operations in padding
    E = topi.nn.pad(C, [1, 1])
    compute_dag = auto_scheduler.ComputeDAG([A, B, E])
    assert abs(compute_dag.flop_ct - 2 * n**3) < 0.5

    # An explicit "FLOP" attr overrides the analyzed count for that stage.
    F = te.compute((n, n), lambda i, j: E[i, j], name="F", attrs={"FLOP": 1234})
    compute_dag = auto_scheduler.ComputeDAG([A, B, F])
    assert abs(compute_dag.flop_ct - (2 * n**3 + 1234)) < 0.5

    # if_then_else counts as a single op per output element: n^2 total.
    A = te.placeholder((n, n), dtype="float32", name="A")
    F = te.compute((n, n), lambda i, j: te.if_then_else(A[i, j] > 0, A[i, j], 0))
    compute_dag = auto_scheduler.ComputeDAG([A, F])
    assert abs(compute_dag.flop_ct - n**2) < 0.5
def test_cpu_matmul():
    """Check per-store features extracted for a scheduled 512^3 matmul on CPU.

    Builds a tiled/annotated schedule, extracts features, identifies the
    three buffers by access type, then validates memory-touch statistics
    and the annotation counters.
    """
    dag = auto_scheduler.ComputeDAG(matmul_auto_scheduler_test(512, 512, 512))
    s = dag.get_init_state()
    C = s.stage_ops[2]

    # Tile i by 16 and j by 8, then reorder and annotate the loop nest.
    i, j, k = s[C].iters
    io, ii = s.split(C, i, [16])
    jo, ji = s.split(C, j, [8])
    s.reorder(C, [io, jo, k, ji, ii])
    s.vectorize(C, ji)
    s.parallel(C, io)
    s.parallel(C, jo)
    s.unroll(C, k)

    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(compute_dag=dag, workload_key="test", target=target)

    names = auto_scheduler.feature.get_per_store_feature_names()
    fea = auto_scheduler.feature.get_per_store_features_from_states([s], task)[0]

    stage_0 = fea[0]
    assert len(stage_0) == len(names), "%d vs %d" % (len(stage_0), len(names))
    fea_dict = dict(zip(names, stage_0))

    # Identify the buffers by their access pattern:
    #   C: read-write; B: read-only with stride 0; A: the remaining read.
    for name in ["B0", "B1", "B2"]:
        if fequal(fea_dict[name + ".acc_type.kReadWrite"], 1.0):
            c_name = name
        if fequal(fea_dict[name + ".acc_type.kRead"], 1.0):
            if fequal(fea_dict[name + ".stride"], 0.0):
                b_name = name
            else:
                a_name = name

    # lowered IR:
    # Placeholder: A, B
    # parallel i.0 (0,32)
    #   parallel j.0 (0,64)
    #     unroll k (0,512)
    #       vectorize j.1 (0,8)
    #         for i.1 (0,16)
    #           C...] = A[...] * B[...]

    # check touched memory in bytes, touched unique memory in bytes, reuse distance, etc.
    assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1))
    assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1))
    assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1))
    assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1))
    assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1))

    # check annotations
    assert fequal(fea_dict["unroll_num"], math.log2(1 + 1))
    # assert fequal(fea_dict["unroll_type.kPosInnerReduce"], 1.0)
    assert fequal(fea_dict["vec_num"], math.log2(1 + 1))
    assert fequal(fea_dict["parallel_num"], math.log2(2 + 1))
    assert fequal(fea_dict["parallel_prod"], math.log2((512 * 512 / 16 / 8) + 1))
def test_rfactor():
    """Check iterator extents after applying rfactor at two different positions.

    Splits the reduction axis k of an 8x8x512 matmul into (k_o=32, k_i=16),
    then rfactors once on k_o and once on k_i, verifying the loop extents of
    both the rfactored stage and the replacement stage.
    """
    A, B, C = matmul_auto_scheduler_test(8, 8, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    s0 = dag.get_init_state()
    ko, ki = s0.split(C, s0[C].iters[2], [16])

    # Case 1: rfactor on the outer reduction axis k_o.
    s1 = s0.copy()
    C_r = s1.rfactor(C, ko, 2)
    # Placeholder: A, B
    # for i (0,8)
    #   for j (0,8)
    #     for k_o (0,32)
    #       for k_i (0,16)
    #         C.rf = ...
    # for ax0 (0,8)
    #   for ax1 (0,8)
    #     for k_o_v (0,32)
    #       C.repl = ...
    assert s1[C_r].iters[0].range.extent == 8
    assert s1[C_r].iters[1].range.extent == 8
    assert s1[C_r].iters[2].range.extent == 32
    assert s1[C_r].iters[3].range.extent == 16
    assert s1[C].iters[0].range.extent == 8
    assert s1[C].iters[1].range.extent == 8
    assert s1[C].iters[2].range.extent == 32

    # Case 2: rfactor on the inner reduction axis k_i.
    s2 = s0.copy()
    C_r = s2.rfactor(C, ki, 2)
    # Placeholder: A, B
    # for i (0,8)
    #   for j (0,8)
    #     for k_i (0,16)
    #       for k_o (0,32)
    #         C.rf = ...
    # for ax0 (0,8)
    #   for ax1 (0,8)
    #     for k_i_v (0,16)
    #       C.repl = ...
    assert s2[C_r].iters[0].range.extent == 8
    assert s2[C_r].iters[1].range.extent == 8
    assert s2[C_r].iters[2].range.extent == 16
    assert s2[C_r].iters[3].range.extent == 32
    assert s2[C].iters[0].range.extent == 8
    assert s2[C].iters[1].range.extent == 8
    assert s2[C].iters[2].range.extent == 16
def test_apply_steps_with_layout_rewrite_corner_case():
    """Regression check: layout rewrite must not crash on a degenerate 1x1x1 matmul.

    Applies compute_root + fuse + parallel, then lowers with
    REWRITE_FOR_PRE_TRANSFORMED; success means no exception is raised.
    """
    A, B, C = matmul_auto_scheduler_test(1, 1, 1)
    dag = auto_scheduler.ComputeDAG([A, B, C])

    state = dag.get_init_state()
    state.compute_root(C)
    fused = state.fuse(C, [state[C].iters[0], state[C].iters[1]])
    state.parallel(C, fused)

    _, bufs = dag.apply_steps_from_state(
        state,
        layout_rewrite=auto_scheduler.LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED,
    )
def test_follow_split_follow_fused_split():
    """Check follow_split and follow_fused_split against their source splits.

    A cache-write stage must reproduce the extents of the splits (and fused
    splits) applied to the original stage when it "follows" them.
    """
    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    s0 = dag.get_init_state()
    C_global = s0.cache_write(C, "global")

    # Split axis i of C; record the index of that transform step.
    its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True)
    split_step0 = len(s0.transform_steps) - 1

    # follow_split at each level must copy the leading split extents.
    for level in range(1, 6):
        trial = s0.copy()
        trial.follow_split(C_global, trial[C_global].iters[0], split_step0, level)
        for idx in range(0, level):
            assert trial[C].iters[idx].range.extent == trial[C_global].iters[idx].range.extent

    # Split axis j as well, then interleave the two split results and fuse
    # adjacent pairs so follow_fused_split has a fused target to mirror.
    its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8])
    split_step1 = len(s0.transform_steps) - 1
    interleaved = [it for pair in zip(its0, its1) for it in pair]
    s0.reorder(C, interleaved)
    for idx in range(0, 5):
        s0.fuse(C, [s0[C].iters[idx], s0[C].iters[idx + 1]])

    # factor_or_nparts=False: match against the first resulting iterator.
    for level in range(0, 4):
        trial = s0.copy()
        trial.follow_fused_split(
            C_global, trial[C_global].iters[0], [split_step0, split_step1], level, False
        )
        assert trial[C].iters[level + 1].range.extent == trial[C_global].iters[0].range.extent

    # factor_or_nparts=True: match against the second resulting iterator.
    for level in range(0, 4):
        trial = s0.copy()
        trial.follow_fused_split(
            C_global, trial[C_global].iters[0], [split_step0, split_step1], level, True
        )
        assert trial[C].iters[level + 1].range.extent == trial[C_global].iters[1].range.extent
def test_split_fuse_reorder_annotation():
    """Check split/fuse/reorder extents and every annotation kind on a matmul state."""
    A, B, C = matmul_auto_scheduler_test(N=512, M=512, K=512)
    dag = auto_scheduler.ComputeDAG([A, B, C])

    # --- split / reorder / fuse on a fresh state ---
    s0 = dag.get_init_state()
    i, j, k = s0[C].iters
    assert i.range.extent == 512

    io, ii = s0.split(C, i, [16])
    assert s0[C].iters[0] == io
    assert s0[C].iters[1] == ii
    assert io.range.extent == 32
    assert ii.range.extent == 16

    jo, ji = s0.split(C, j, [8])
    assert jo.range.extent == 64
    assert ji.range.extent == 8

    s0.reorder(C, [io, jo, k, ji, ii])
    assert s0[C].iters[2].range.extent == 512

    fused_it = s0.fuse(C, [io, jo])
    assert fused_it.range.extent == 2048

    # --- three-way splits plus one of each annotation on a second state ---
    s1 = dag.get_init_state()
    i, j, _ = s1[C].iters
    i1, i2, i3 = s1.split(C, i, [8, 2])
    j1, j2, j3 = s1.split(C, j, [32, 8], False)
    assert s1[C].iters[0].range.extent == 32
    assert s1[C].iters[1].range.extent == 8
    assert s1[C].iters[2].range.extent == 2
    assert s1[C].iters[3].range.extent == 32
    assert s1[C].iters[4].range.extent == 8
    assert s1[C].iters[5].range.extent == 2

    # Each annotation call must return the annotated iterator, and the
    # annotation must round-trip through the annotation translation table.
    annotation_table = auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE

    annotated = s1.bind(C, i1, "blockIdx.x")
    assert annotated == s1[C].iters[0]
    assert annotated.annotation == annotation_table["blockIdx.x"]

    annotated = s1.bind(C, i2, "vthread")
    assert annotated == s1[C].iters[1]
    assert annotated.annotation == annotation_table["vthread"]

    annotated = s1.bind(C, i3, "threadIdx.y")
    assert annotated == s1[C].iters[2]
    assert annotated.annotation == annotation_table["threadIdx.y"]

    annotated = s1.parallel(C, j1)
    assert annotated == s1[C].iters[3]
    assert annotated.annotation == annotation_table["parallel"]

    annotated = s1.unroll(C, j2)
    assert annotated == s1[C].iters[4]
    assert annotated.annotation == annotation_table["unroll"]

    annotated = s1.vectorize(C, j3)
    assert annotated == s1[C].iters[5]
    assert annotated.annotation == annotation_table["vectorize"]