def test_estimate_flop():
    """Check the FLOP count estimated by ComputeDAG for several small graphs.

    Covers a plain matmul, matmul followed by relu, matmul followed by padding,
    a stage carrying an explicit ``"FLOP"`` attribute, and a standalone
    element-wise ``if_then_else``.
    """
    size = 512
    gemm_flops = 2 * size**3

    def assert_flop_ct(tensors, expected):
        # flop_ct is compared with a 0.5 tolerance rather than exact equality.
        dag = auto_scheduler.ComputeDAG(tensors)
        assert abs(dag.flop_ct - expected) < 0.5

    A, B, C = matmul_auto_scheduler_test(size, size, size)
    assert_flop_ct([A, B, C], gemm_flops)

    relu_out = topi.nn.relu(C)
    assert_flop_ct([A, B, relu_out], gemm_flops + size * size)

    # The comparison operations introduced by padding must not be counted.
    padded = topi.nn.pad(C, [1, 1])
    assert_flop_ct([A, B, padded], gemm_flops)

    # Stage "F" only copies its input; the expected total adds the value of
    # its "FLOP" attribute (1234).
    annotated = te.compute(
        (size, size), lambda i, j: padded[i, j], name="F", attrs={"FLOP": 1234}
    )
    assert_flop_ct([A, B, annotated], gemm_flops + 1234)

    # A standalone element-wise select is expected to cost size * size.
    src = te.placeholder((size, size), dtype="float32", name="A")
    selected = te.compute(
        (size, size), lambda i, j: te.if_then_else(src[i, j] > 0, src[i, j], 0)
    )
    assert_flop_ct([src, selected], size**2)
def test_estimate_flop():
    """Check ComputeDAG.flop_ct for matmul, matmul + relu and matmul + padding."""
    n = 512
    A, B, C = matmul_auto_scheduler_test(n, n, n)
    gemm_flops = 2 * n**3

    def check(out, expected):
        # Build a DAG ending in `out` and compare its FLOP estimate (tolerance 0.5).
        flops = auto_scheduler.ComputeDAG([A, B, out]).flop_ct
        assert abs(flops - expected) < 0.5

    check(C, gemm_flops)
    check(topi.nn.relu(C), gemm_flops + n * n)
    # The comparisons performed by padding must not be counted as FLOPs.
    check(topi.nn.pad(C, [1, 1]), gemm_flops)
def generate_sketches(workload_func, args, target, print_for_debug=False):
    """Return the sketches a fresh SketchPolicy generates for one workload.

    Parameters
    ----------
    workload_func : callable
        Workload function passed to ``auto_scheduler.make_workload_key``.
    args : tuple
        Arguments for ``workload_func``, also passed to ``make_workload_key``.
    target : str
        Target string, converted with ``tvm.target.create``.
    print_for_debug : bool
        Forwarded to ``SketchPolicy.generate_sketches``.
    """
    key = auto_scheduler.make_workload_key(workload_func, args)
    task = auto_scheduler.SearchTask(
        auto_scheduler.ComputeDAG(key), key, tvm.target.create(target)
    )
    policy = auto_scheduler.SketchPolicy(task, verbose=0)
    return policy.generate_sketches(print_for_debug)
def search_common(workload=matmul_auto_scheduler_test, target="llvm",
                  search_policy='empty', seed=random.randint(1, 1 << 30), runner='local',
                  cost_model=auto_scheduler.RandomModel(), num_measure_trials=2,
                  init_search_callbacks=None):
    """Tune ``workload`` at N=128 for a few trials and verify the best schedule.

    The best record from a temporary log file is lowered, built for ``target``,
    and run. Its output is compared with ``numpy.dot`` of the two random inputs.

    Parameters
    ----------
    workload : callable
        Workload function passed to ``make_workload_key`` with ``(N, N, N)``.
    target : str
        Target string, converted with ``tvm.target.create``.
    search_policy : str
        ``'empty'`` selects EmptyPolicy and ``'sketch'`` selects SketchPolicy.
        Any other value is passed to ``auto_schedule`` unchanged.
    seed : int
        Seed for Python's ``random`` module; reported if verification fails.
        NOTE: the default value is drawn once, when this function is defined.
    runner : str
        Runner forwarded to ``TuningOptions``.
    cost_model :
        NOTE(review): currently not used by this function.
    num_measure_trials : int
        Number of measurement trials.
    init_search_callbacks : list or None
        Extra search callbacks for SketchPolicy. A ``PreloadMeasuredStates``
        callback is appended to a copy, so the caller's list is left untouched.

    Raises
    ------
    Exception
        If lowering, building, running or verification fails. The original
        error is chained as ``__cause__``.
    """
    print("Test %s schedule search with the default search policy" % (target))

    random.seed(seed)
    N = 128
    workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    target = tvm.target.create(target)
    task = auto_scheduler.SearchTask(dag, workload_key, target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        # Work on a copy. Appending to the caller's list would mutate it, and
        # callers that reuse one list would accumulate PreloadMeasuredStates callbacks.
        init_search_callbacks = list(init_search_callbacks or [])
        init_search_callbacks.append(auto_scheduler.PreloadMeasuredStates(log_file))

        if search_policy == 'empty':
            search_policy = auto_scheduler.EmptyPolicy(task)
        elif search_policy == 'sketch':
            search_policy = auto_scheduler.SketchPolicy(
                task, init_search_callbacks=init_search_callbacks)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=num_measure_trials, runner=runner, verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
        sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, workload_key, target)

        print("==== Python Code ====")
        print(dag.print_python_code_from_state(inp.state))

        try:
            print("==== Lowered Stmt ====")
            print(tvm.lower(sch, args, simple_mode=True))
            mod = tvm.build(sch, args, target)

            ctx = tvm.context(str(target), 0)
            dtype = dag.tensors[0].dtype
            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
            mod(a, b, c)
            tvm.testing.assert_allclose(
                c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
            print("==== Verification passed ====")
        except Exception as err:
            # Chain the original error so its traceback is not lost.
            raise Exception("Error encountered with seed: %d" % (seed)) from err
    print()
def test_split_fuse_reorder():
    """Check iterator extents produced by split / reorder / fuse on a 512^3 matmul."""
    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])

    # --- single-factor splits, then reorder and fuse ---
    state = dag.get_init_state()
    i, j, k = state[C].iters
    assert i.range.extent == 512

    i_outer, i_inner = state.split(C, i, [16])
    assert state[C].iters[0] == i_outer
    assert state[C].iters[1] == i_inner
    assert i_outer.range.extent == 32
    assert i_inner.range.extent == 16

    j_outer, j_inner = state.split(C, j, [8])
    assert j_outer.range.extent == 64
    assert j_inner.range.extent == 8

    state.reorder(C, [i_outer, j_outer, k, j_inner, i_inner])
    assert state[C].iters[2].range.extent == 512

    fused = state.fuse(C, [i_outer, j_outer])
    assert fused.range.extent == 2048

    # --- multi-level splits ---
    # With the default flag the lengths [8, 2] become the inner extents and the
    # remainder (32) is outermost. With False the lengths [32, 8] become the
    # outer extents and the remainder (2) is innermost.
    state = dag.get_init_state()
    i, j, _ = state[C].iters
    i1, i2, i3 = state.split(C, i, [8, 2])
    j1, j2, j3 = state.split(C, j, [32, 8], False)
    for idx, extent in enumerate([32, 8, 2, 32, 8, 2]):
        assert state[C].iters[idx].range.extent == extent
def test_record_compute_at_root_inline_cache_read_write():
    """Round-trip cache_write, cache_read, compute_at, compute_inline and
    compute_root steps through ``record_common``."""
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    A_relu = topi.nn.relu(A)
    B = te.placeholder((512, 512), name="B")
    red = te.reduce_axis((0, 512), name="k")
    C = te.compute(
        (512, 512),
        lambda i, j: te.sum(A_relu[i][red] * B[red][j], axis=[red]),
        name="C",
    )
    dag = auto_scheduler.ComputeDAG([A, B, C])
    state = dag.get_init_state()

    # cache_write of C, attached at C's outermost loop.
    C_shared = state.cache_write(C, "shared")
    state.compute_at(C_shared, C, state[C].iters[0])

    # cache_read of B for the C_shared consumer, attached at C_shared.iters[2].
    B_global = state.cache_read(B, "global", [C_shared])
    state.compute_at(B_global, C_shared, state[C_shared].iters[2])

    state.compute_inline(A_relu)
    state.compute_root(C_shared)

    record_common(dag, state)
def test_record_follow_split_follow_fused_split():
    """Round-trip follow_split and follow_fused_split steps through ``record_common``."""
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")
    D = topi.nn.relu(C)
    E = topi.nn.relu(D)
    dag = auto_scheduler.ComputeDAG([A, B, E])
    s = dag.get_init_state()

    # Follow Split: split C.iters[0] into five parts, then make C.iters[5]
    # follow that split step with 4 levels.
    s.split(C, s[C].iters[0], [4, 2, 8, 4], True)
    split_step0 = len(s.transform_steps) - 1
    s.follow_split(C, s[C].iters[5], split_step0, 4)

    # Follow Fused Split: split E.iters[0] and then E.iters[5] (the iterator
    # that followed the first split), fuse iters[i] with iters[i + 1] for
    # i = 0..4 (no reorder is applied first), and make D.iters[0] follow both
    # split steps.
    s.split(E, s[E].iters[0], [4, 2, 8, 4], True)
    split_step1 = len(s.transform_steps) - 1
    s.split(E, s[E].iters[5], [2, 4, 2, 4], True)
    split_step2 = len(s.transform_steps) - 1
    for i in range(0, 5):
        s.fuse(E, [s[E].iters[i], s[E].iters[i + 1]])
    s.follow_fused_split(D, s[D].iters[0], [split_step1, split_step2], 2, True)

    record_common(dag, s)
def test_record_split_reorder_fuse_annotation():
    """Round-trip split, reorder, fuse and annotation steps through ``record_common``."""
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    red = te.reduce_axis((0, 512), name="k")
    C = te.compute(
        (512, 512), lambda i, j: te.sum(A[i][red] * B[red][j], axis=[red]), name="C"
    )
    dag = auto_scheduler.ComputeDAG([A, B, C])
    state = dag.get_init_state()

    i_parts = state.split(C, state[C].iters[0], [4, 8, 8])
    j_parts = state.split(C, state[C].iters[4], [8, 4, 4])
    # Interleave the i/j tiles. The reduction iterator (now iters[8]) is placed
    # between the last two tiles.
    new_order = [
        i_parts[0], j_parts[0],
        i_parts[1], j_parts[1],
        i_parts[2], j_parts[2],
        i_parts[3], state[C].iters[8], j_parts[3],
    ]
    state.reorder(C, new_order)
    state.fuse(C, [state[C].iters[0], state[C].iters[1], state[C].iters[2]])

    state.parallel(C, state[C].iters[0])
    # blockIdx / threadIdx are GPU bindings; they are used here only to
    # exercise record serialization.
    state.bind(C, state[C].iters[1], "blockIdx.x")
    state.bind(C, state[C].iters[2], "threadIdx.z")
    state.bind(C, state[C].iters[3], "vthread")
    state.unroll(C, state[C].iters[4])
    state.vectorize(C, state[C].iters[6])

    record_common(dag, state)
def test_record():
    """Save one heavily transformed state to a log file and read it back.

    The state read back must infer to the same bounds as the original state,
    and that result must differ from the DAG's initial state.
    """
    if not tvm.runtime.enabled("llvm"):
        return

    A = te.placeholder((512, 512), name='A')
    B = te.placeholder((512, 512), name='B')
    k = te.reduce_axis((0, 512), name='k')
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
    D = topi.nn.relu(C)
    # The second matmul needs its own reduction axis.
    k = te.reduce_axis((0, 512), name='k')
    # Named distinctly from 'C' so the two matmul stages can be told apart by name.
    E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
    F = topi.nn.relu(E)

    dag = auto_scheduler.ComputeDAG([A, B, F])
    s = dag.get_init_state()

    # Split
    its0 = s.split(C, s[C].iters[0], [4, 8, 8])
    its1 = s.split(C, s[C].iters[4], [8, 4, 4])
    # Reorder
    s.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8],
                  its1[3]])
    # Fuse
    s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]])
    # Compute at
    s.split(F, s[F].iters[0], [2])
    s.compute_at(E, F, s[F].iters[0])
    # Compute inline, then compute root, on the same stage.
    s.compute_inline(D)
    s.compute_root(D)
    # Parallel
    s.parallel(C, s[C].iters[0])
    # Thread bind (blockIdx/threadIdx are GPU bindings, used here only for record testing)
    s.bind(C, s[C].iters[1], "blockIdx.x")
    s.bind(C, s[C].iters[2], "threadIdx.z")
    s.bind(C, s[C].iters[3], "vthread")
    # Unroll
    s.unroll(C, s[C].iters[4])
    # Vectorize
    s.vectorize(C, s[C].iters[6])

    target = tvm.target.create("llvm")
    task = auto_scheduler.SearchTask(dag, "test", target)

    inp = auto_scheduler.measure.MeasureInput(task, s)
    res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)

    with tempfile.NamedTemporaryFile() as fp:
        auto_scheduler.save_records(fp.name, [inp], [res])

        log_reader = auto_scheduler.RecordReader(fp.name)
        inputs, _ = log_reader.read_lines()
        assert len(inputs) == 1

        s1 = dag.infer_bound_from_state(s)
        s2 = dag.infer_bound_from_state(inputs[0].state)

        assert s1 == s2
        assert not (s1 == dag.get_init_state())
def test_cpu_matmul():
    """Check per-store features of a tiled and annotated 512^3 matmul on llvm."""
    dag = auto_scheduler.ComputeDAG(matmul_auto_scheduler_test(512, 512, 512))
    state = dag.get_init_state()
    C = state.stage_ops[2]

    i, j, k = state[C].iters
    i_outer, i_inner = state.split(C, i, [16])
    j_outer, j_inner = state.split(C, j, [8])
    state.reorder(C, [i_outer, j_outer, k, j_inner, i_inner])
    state.vectorize(C, j_inner)
    state.parallel(C, i_outer)
    state.parallel(C, j_outer)
    state.unroll(C, k)

    task = auto_scheduler.SearchTask(
        compute_dag=dag, workload_key="test", target=tvm.target.Target("llvm")
    )
    names = auto_scheduler.feature.get_per_store_feature_names()
    features = auto_scheduler.feature.get_per_store_features_from_states([state], task)[0]

    first_stage = features[0]
    assert len(first_stage) == len(names), "%d vs %d" % (len(first_stage), len(names))
    fea_dict = dict(zip(names, first_stage))

    # Identify the buffer slots. The read-write buffer is C. Among the
    # read-only buffers, the one with stride 0 is B and the other is A.
    for buf in ("B0", "B1", "B2"):
        if fequal(fea_dict[buf + ".acc_type.kReadWrite"], 1.0):
            c_name = buf
        if fequal(fea_dict[buf + ".acc_type.kRead"], 1.0):
            if fequal(fea_dict[buf + ".stride"], 0.0):
                b_name = buf
            else:
                a_name = buf

    # Lowered IR:
    #   Placeholder: A, B
    #   parallel i.0 (0,32)
    #     parallel j.0 (0,64)
    #       unroll k (0,512)
    #         vectorize j.1 (0,8)
    #           for i.1 (0,16)
    #             C[...] = A[...] * B[...]

    # Touched bytes, unique bytes, reuse distance and reuse count.
    assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1))
    assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1))
    assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1))
    assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1))
    assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1))

    # Annotation features. (A check of "unroll_type.kPosInnerReduce" is
    # intentionally disabled.)
    assert fequal(fea_dict["unroll_num"], math.log2(1 + 1))
    assert fequal(fea_dict["vec_num"], math.log2(1 + 1))
    assert fequal(fea_dict["parallel_num"], math.log2(2 + 1))
    assert fequal(fea_dict["parallel_prod"], math.log2((512 * 512 / 16 / 8) + 1))
def test_split_fuse_reorder_annotation():
    """Check split/reorder/fuse extents and the annotations set by
    bind / parallel / unroll / vectorize on a 512^3 matmul."""
    A, B, C = matmul_auto_scheduler_test(N=512, M=512, K=512)
    dag = auto_scheduler.ComputeDAG([A, B, C])

    # --- split, reorder and fuse ---
    st = dag.get_init_state()
    i, j, k = st[C].iters
    assert i.range.extent == 512

    i_out, i_in = st.split(C, i, [16])
    assert st[C].iters[0] == i_out
    assert st[C].iters[1] == i_in
    assert i_out.range.extent == 32
    assert i_in.range.extent == 16

    j_out, j_in = st.split(C, j, [8])
    assert j_out.range.extent == 64
    assert j_in.range.extent == 8

    st.reorder(C, [i_out, j_out, k, j_in, i_in])
    assert st[C].iters[2].range.extent == 512

    fused = st.fuse(C, [i_out, j_out])
    assert fused.range.extent == 2048

    # --- three-level splits, then annotate each resulting iterator ---
    s1 = dag.get_init_state()
    i, j, _ = s1[C].iters
    i1, i2, i3 = s1.split(C, i, [8, 2])
    j1, j2, j3 = s1.split(C, j, [32, 8], False)
    for idx, extent in enumerate([32, 8, 2, 32, 8, 2]):
        assert s1[C].iters[idx].range.extent == extent

    trans_table = auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE

    def check_annotated(result, idx, key):
        # The returned iterator sits at position idx and carries the code for key.
        assert result == s1[C].iters[idx]
        assert result.annotation == trans_table[key]

    check_annotated(s1.bind(C, i1, "blockIdx.x"), 0, "blockIdx.x")
    check_annotated(s1.bind(C, i2, "vthread"), 1, "vthread")
    check_annotated(s1.bind(C, i3, "threadIdx.y"), 2, "threadIdx.y")
    check_annotated(s1.parallel(C, j1), 3, "parallel")
    check_annotated(s1.unroll(C, j2), 4, "unroll")
    check_annotated(s1.vectorize(C, j3), 5, "vectorize")
def test_compute_at_root_inline():
    """Check compute_inline / compute_at / compute_root on a conv2d+bn+relu DAG."""
    dag = auto_scheduler.ComputeDAG(
        conv2d_nchw_bn_relu_auto_scheduler_test(
            N=1, H=224, W=224, CI=3, CO=64, kernel_size=7, strides=2, padding=3
        )
    )
    state = dag.get_init_state()

    # Other stage indices: 0 data, 1 padding, 2 kernel, 4 bias, 6 bn_scale,
    # 8 bn_offset.
    ops = state.stage_ops
    conv, bias_add, bn_mul, bn_add, relu = ops[3], ops[5], ops[7], ops[9], ops[10]

    # compute_at codes as asserted below: 0 after compute_root, 1 after
    # compute_inline, 2 after compute_at.
    AT_ROOT, INLINED, AT_ITER = 0, 1, 2

    def assert_conv_extents(st, expected):
        for idx, extent in enumerate(expected):
            assert st[conv].iters[idx].range.extent == extent

    full_extents = [1, 64, 112, 112, 3, 7, 7]

    for op in (bn_add, bn_mul, bias_add):
        state.compute_inline(op)
        assert state[op].compute_at == INLINED
    assert_conv_extents(state, full_extents)

    state.compute_at(conv, relu, state[relu].iters[2])
    assert state[conv].compute_at == AT_ITER
    state = dag.infer_bound_from_state(state)
    assert_conv_extents(state, [1, 1, 1, 112, 3, 7, 7])

    state.compute_root(bn_mul)
    assert state[bn_mul].compute_at == AT_ROOT
    state.compute_root(conv)
    assert state[conv].compute_at == AT_ROOT
    state = dag.infer_bound_from_state(state)
    assert_conv_extents(state, full_extents)
def test_invalid_compute_dag():
    """ComputeDAG must reject an invalid compute definition with a TVMError."""
    # Build the definition outside the try block so that only the ComputeDAG
    # construction is allowed to satisfy the expected failure.
    A, B = invalid_compute_definition()
    try:
        auto_scheduler.ComputeDAG([A, B])
    except tvm.TVMError:
        return
    raise AssertionError("ComputeDAG did not reject the invalid compute definition")
def test_stage_order():
    """Check that cache-read stages directly follow their source stage, and
    that an auto_scheduler State built with the same steps has the same
    stage order as the TE schedule."""
    N = 512
    A, B, C, D, E = parallel_matmul_auto_scheduler_test(N)
    read_plan = ((A, 0), (B, 0), (A, 1), (C, 1))

    # Reference order from a plain TE schedule.
    sch = te.create_schedule([D.op, E.op])
    (D_local,) = sch.cache_write([D], "local")
    (E_local,) = sch.cache_write([E], "local")
    te_readers = (D_local, E_local)
    for src, reader_idx in read_plan:
        sch.cache_read(src, "shared", [te_readers[reader_idx]])
    te_ops = auto_scheduler.ComputeDAG(sch).get_init_state().stage_ops

    # 3 placeholders, 4 x.shared, 2 {D,E}.local and 2 {D,E} compute stages.
    assert len(te_ops) == 11

    for pos, op in enumerate(te_ops):
        if op.name == "A":
            assert (te_ops[pos + 1].name == "A.d.shared"
                    and te_ops[pos + 2].name == "A.shared")
        elif op.name in ["B", "C"]:
            assert te_ops[pos + 1].name == "%s.shared" % op.name

    # Apply the same steps to an auto_scheduler State.
    state = auto_scheduler.ComputeDAG([A, B, C, D, E]).get_init_state()
    D_local = state.cache_write(D, "local")
    E_local = state.cache_write(E, "local")
    state_readers = (D_local, E_local)
    for src, reader_idx in read_plan:
        state.cache_read(src, "shared", [state_readers[reader_idx]])
    state_ops = state.stage_ops

    assert len(te_ops) == len(state_ops)
    assert [op.name for op in te_ops] == [op.name for op in state_ops]
def get_tiled_matmul():
    """Return ``(dag, state)`` for a tiled 512^3 matmul.

    The i and j loops are each split into four levels and interleaved
    (i0, j0, i1, j1, ...). The reduction loop is placed last.
    """
    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    state = dag.get_init_state()

    i_tiles = state.split(C, state[C].iters[0], [4, 8, 8])
    j_tiles = state.split(C, state[C].iters[4], [8, 4, 4])
    order = [it for pair in zip(i_tiles, j_tiles) for it in pair]
    order.append(state[C].iters[8])
    state.reorder(C, order)
    return dag, state
def test_estimate_flop():
    """Check ComputeDAG.flop_ct for matmul, relu, padding and a "FLOP" attribute."""
    n = 512
    A, B, C = matmul_auto_scheduler_test(n, n, n)
    base = 2 * n**3

    def flop_ct_of(out):
        # FLOP estimate of the DAG whose inputs are A, B and whose output is `out`.
        return auto_scheduler.ComputeDAG([A, B, out]).flop_ct

    assert abs(flop_ct_of(C) - base) < 0.5

    D = topi.nn.relu(C)
    assert abs(flop_ct_of(D) - (base + n * n)) < 0.5

    # The comparisons performed by padding must not be counted.
    E = topi.nn.pad(C, [1, 1])
    assert abs(flop_ct_of(E) - base) < 0.5

    # F only copies E. Its "FLOP" attribute (1234) is expected on top of the matmul.
    F = te.compute((n, n), lambda i, j: E[i, j], name='F', attrs={"FLOP": 1234})
    assert abs(flop_ct_of(F) - (base + 1234)) < 0.5
def test_apply_steps_with_layout_rewrite_corner_case():
    """Smoke test: apply_steps_from_state with REWRITE_FOR_PRE_TRANSFORMED on a
    1x1x1 matmul whose fused spatial iterator is parallelized."""
    A, B, C = matmul_auto_scheduler_test(1, 1, 1)
    dag = auto_scheduler.ComputeDAG([A, B, C])

    state = dag.get_init_state()
    state.compute_root(C)
    fused = state.fuse(C, [state[C].iters[0], state[C].iters[1]])
    state.parallel(C, fused)

    rewrite = auto_scheduler.LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
    _, _bufs = dag.apply_steps_from_state(state, layout_rewrite=rewrite)
def test_random_model():
    """RandomModel.predict must return one score per input state."""
    if not tvm.runtime.enabled("llvm"):
        return

    size = 128
    key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (size, size, size))
    dag = auto_scheduler.ComputeDAG(key)
    task = auto_scheduler.SearchTask(dag, key, tvm.target.create('llvm'))

    model = auto_scheduler.RandomModel()
    model.update([], [])
    predicted = model.predict(task, [dag.init_state, dag.init_state])
    assert len(predicted) == 2
def test_rfactor():
    """Check iterator extents after rfactor on the outer or the inner half of a
    split reduction loop."""
    A, B, C = matmul_auto_scheduler_test(8, 8, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    base = dag.get_init_state()
    k_outer, k_inner = base.split(C, base[C].iters[2], [16])

    def check(st, stage, expected):
        for idx, extent in enumerate(expected):
            assert st[stage].iters[idx].range.extent == extent

    # rfactor on k_outer:
    #   Placeholder: A, B
    #   for i (0,8)
    #     for j (0,8)
    #       for k_o (0,32)
    #         for k_i (0,16)
    #           C.rf = ...
    #   for ax0 (0,8)
    #     for ax1 (0,8)
    #       for k_o_v (0,32)
    #         C.repl = ...
    outer_state = base.copy()
    rf_stage = outer_state.rfactor(C, k_outer, 2)
    check(outer_state, rf_stage, [8, 8, 32, 16])
    check(outer_state, C, [8, 8, 32])

    # rfactor on k_inner:
    #   Placeholder: A, B
    #   for i (0,8)
    #     for j (0,8)
    #       for k_i (0,16)
    #         for k_o (0,32)
    #           C.rf = ...
    #   for ax0 (0,8)
    #     for ax1 (0,8)
    #       for k_i_v (0,16)
    #         C.repl = ...
    inner_state = base.copy()
    rf_stage = inner_state.rfactor(C, k_inner, 2)
    check(inner_state, rf_stage, [8, 8, 16, 32])
    check(inner_state, C, [8, 8, 16])
def get_sample_records(number):
    """Generate a list of random MeasureInput and MeasureResult pairs.

    The inputs are states sampled from a SketchPolicy's initial population for
    a 128^3 matmul on llvm. Each result carries one random cost drawn from
    ``np.random.uniform(0.5, 1.0)``.

    Returns
    -------
    (task, dag, inputs, results)
    """
    size = 128
    key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (size, size, size))
    dag = auto_scheduler.ComputeDAG(key)
    task = auto_scheduler.SearchTask(dag, key, tvm.target.create('llvm'))

    policy = auto_scheduler.SketchPolicy(task, verbose=0)
    sampled = policy.sample_initial_population(number)
    inputs = [auto_scheduler.MeasureInput(task, st) for st in sampled]
    results = [
        auto_scheduler.MeasureResult([np.random.uniform(0.5, 1.0)], 0, "", 0.1, 0)
        for _ in inputs
    ]
    return task, dag, inputs, results
def test_mutate_tile_size():
    """Evolutionary search must reach a tile_k=2 state by mutating tile sizes.

    The search starts only from "bad" states (no tile_k=2 loop). A mock cost
    model scores tile_k=2 states as 1 and everything else as 0. This unit test
    has been tested with 1,000 runs with no failures, meaning that the failure
    rate is less than 0.1%.
    """

    class MockCostModel(PythonBasedModel):
        """A mock cost model that rates 1 only for the states with tile_k=2."""

        @staticmethod
        def is_good_state(state):
            # A tile_k=2 state prints a "k.1" loop with range "(0,2)" on one line.
            return any(
                "k.1" in line and "(0,2)" in line for line in str(state).split("\n")
            )

        def predict(self, task, states):
            return [1 if self.is_good_state(st) else 0 for st in states]

    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (10, 10, 4))
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))
    policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0)

    initial = policy.sample_initial_population()[:50]
    bad_states = [st for st in initial if not MockCostModel.is_good_state(st)]

    new_states = policy.evolutionary_search(bad_states, 50)
    assert any(MockCostModel.is_good_state(st) for st in new_states)
def test_cpu_fusion():
    """Check reuse features after computing B inside C's inner loop."""

    def fusion_test(N, M):
        A = te.placeholder((N, M), name="A")
        B = te.compute((N, M), lambda i, j: A[i][j], name="B")
        C = te.compute((N, M), lambda i, j: B[i][j], name="C")
        return [A, B, C]

    dag = auto_scheduler.ComputeDAG(fusion_test(64, 32))
    state = dag.get_init_state()
    # Attach stage 1 (B) at iters[1] of stage 2 (C).
    state.compute_at(1, 2, state.stages[2].iters[1])

    task = auto_scheduler.SearchTask(
        compute_dag=dag, workload_key="test", target=tvm.target.Target("llvm")
    )
    names = auto_scheduler.feature.get_per_store_feature_names()
    features = auto_scheduler.feature.get_per_store_features_from_states([state], task)[0]

    # Lowered IR:
    #   Placeholder: A
    #   for i (0,64)
    #     for j (0,32)
    #       for ii (1)
    #         for jj (1)
    #           B[...] = A[...]
    #       C[...] = B[...]

    # Every store whose reuse type is kSerialMultipleReadWrite must report the
    # expected reuse distances; at least one such store must exist.
    found = False
    for stage_fea in features:
        for pos, (name, value) in enumerate(zip(names, stage_fea)):
            if not ("reuse_type.kSerialMultipleReadWrite" in name and value > 0.5):
                continue
            # Two slots after the reuse type: reuse distance in #iter.
            assert fequal(stage_fea[pos + 2], 1.0)
            # Three slots after: reuse distance in bytes.
            assert fequal(stage_fea[pos + 3], math.log2(16 + 1))
            found = True
    assert found
def test_record_pragma_storage_align_rfactor():
    """Round-trip rfactor, pragma and storage_align steps through ``record_common``."""
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    red = te.reduce_axis((0, 512), name="k")
    C = te.compute(
        (512, 512), lambda i, j: te.sum(A[i][red] * B[red][j], axis=[red]), name="C"
    )
    dag = auto_scheduler.ComputeDAG([A, B, C])
    state = dag.get_init_state()

    # Split C.iters[2] (the reduction loop) and rfactor its outer part.
    k_outer, _ = state.split(C, state[C].iters[2], [16])
    state.rfactor(C, k_outer, 2)
    state.pragma(C, state[C].iters[0], "auto_unroll_max_step$64")
    state.storage_align(C, state[C].iters[-1], 8, 4)

    record_common(dag, state)
def test_follow_split_follow_fused_split():
    """follow_split / follow_fused_split must reproduce the extents of the
    split steps they follow."""
    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    state = dag.get_init_state()
    C_global = state.cache_write(C, "global")

    i_parts = state.split(C, state[C].iters[0], [4, 2, 8, 4], True)
    i_split_step = len(state.transform_steps) - 1

    # follow_split with n_split levels: the first n_split iterators of
    # C_global must match those of C.
    for n_split in range(1, 6):
        trial = state.copy()
        trial.follow_split(C_global, trial[C_global].iters[0], i_split_step, n_split)
        for idx in range(n_split):
            assert (trial[C].iters[idx].range.extent
                    == trial[C_global].iters[idx].range.extent)

    j_parts = state.split(C, state[C].iters[5], [2, 2, 4, 8])
    j_split_step = len(state.transform_steps) - 1

    # Interleave i/j parts (i0, j0, i1, j1, ...) and fuse adjacent iterators.
    state.reorder(C, [it for pair in zip(i_parts, j_parts) for it in pair])
    for idx in range(5):
        state.fuse(C, [state[C].iters[idx], state[C].iters[idx + 1]])

    # follow_fused_split: with the last flag False, C_global.iters[0] is
    # compared; with True, C_global.iters[1] is compared.
    for flag, global_idx in ((False, 0), (True, 1)):
        for level in range(4):
            trial = state.copy()
            trial.follow_fused_split(
                C_global, trial[C_global].iters[0], [i_split_step, j_split_step], level, flag
            )
            assert (trial[C].iters[level + 1].range.extent
                    == trial[C_global].iters[global_idx].range.extent)
def test_evo_search():
    """Test evolutionary search.

    The random number generator cannot be mocked, so a mocked cost model
    (``MockCostModel``, not defined in this function) guides the search
    instead. Every tile_k=2 state is removed from the initial population, and
    evolutionary search is expected to find at least one such state again.

    This unit test has been tested with 1,000 runs with no failures, meaning
    that the failure rate is less than 0.1%.
    """
    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (10, 10, 4))
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))
    policy = auto_scheduler.SketchPolicy(task, schedule_cost_model=MockCostModel(), verbose=0)

    def has_tile_k2(state):
        # A tile_k=2 state prints a "k.1" loop with range "(0,2)" on one line.
        return any(
            line.find("k.1") != -1 and line.find("(0,2)") != -1
            for line in str(state).split("\n")
        )

    # Drop all tile_k=2 states and expect evo search to find them again.
    pruned_states = [
        st for st in policy.sample_initial_population(50) if not has_tile_k2(st)
    ]
    new_states = policy.evolutionary_search(pruned_states, 50)
    assert any(has_tile_k2(st) for st in new_states)
def test_dag_measure_local_builder_runner():
    """Build and run a relu(relu(matmul)) workload with LocalBuilder and
    LocalRunner, after a pickle round-trip of its workload registry entry."""
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    red = te.reduce_axis((0, 512), name="k")
    C = te.compute(
        (512, 512), lambda i, j: te.sum(A[i][red] * B[red][j], axis=[red]), name="C"
    )
    D = topi.nn.relu(C)
    E = topi.nn.relu(D)
    tensors = [A, B, E]
    dag = auto_scheduler.ComputeDAG(tensors)

    key = workload_registry.register_workload_tensors(dag.workload_key(), tensors)
    # Serialize the registry entry, pickle round-trip it, drop the local entry
    # and restore it from the copy (presumably mimicking a separate worker).
    entry = workload_registry.serialize_workload_registry_entry(key)
    restored = pickle.loads(pickle.dumps(entry))
    del workload_registry.WORKLOAD_FUNC_REGISTRY[key]
    workload_registry.deserialize_workload_registry_entry(restored)

    task = auto_scheduler.SearchTask(
        compute_dag=dag, workload_key=key, target=tvm.target.Target("llvm")
    )
    for enable_cpu_cache_flush in (True, False):
        measure_input = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
        builder = auto_scheduler.LocalBuilder()
        runner = auto_scheduler.LocalRunner(
            timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
        )
        build_results = builder.build([measure_input])
        assert build_results[0].error_no == 0
        run_results = runner.run([measure_input], build_results)
        assert run_results[0].error_no == 0
def test_cache_read_write():
    """Check cache_read / cache_write steps (with compute_at) on a conv2d graph.

    Covers a plain cache_write, a plain cache_read, a two-level cache_read, a
    cache_read with several readers, and a cache_write on a compute op with
    multiple outputs.
    """
    N, H, W = 4, 7, 7
    CO, CI, KH, KW = 512, 512, 3, 3
    strides, padding = (1, 1), (1, 1)

    data = te.placeholder((N, CI, H, W), name="Data")
    kernel_data = te.placeholder((CO, CI, KH, KW), name="Kernel_data")
    # A single compute op with two outputs; it is cache-written in step 5.
    k0, k1 = te.compute(
        kernel_data.shape,
        lambda *i: (kernel_data(*i) + 1, kernel_data(*i) / 2),
        name="Kernel_split",
    )
    kernel = te.compute(kernel_data.shape, lambda *i: k0(*i) + k1(*i), name="Kernel")
    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1)
    relu = topi.nn.relu(conv)
    add = topi.add(data, relu)

    dag = auto_scheduler.ComputeDAG([data, kernel_data, add])
    s0 = dag.get_init_state()
    pad_temp = s0.stage_ops[1]
    kernel_split = s0.stage_ops[3]

    def check_extents(state, stage, expected):
        # Compare the leading iterator extents of `stage` with `expected`.
        for idx, extent in enumerate(expected):
            assert state[stage].iters[idx].range.extent == extent

    # 0: split + reorder the final add, and inline relu.
    add_iters = s0[add].iters
    batch_split = s0.split(add, s0[add].iters[0], [2])
    s0.reorder(add, [batch_split[0], add_iters[1], batch_split[1], add_iters[2], add_iters[3]])
    s0.compute_inline(relu)

    # 1: simple cache_write with compute_at.
    conv_global = s0.cache_write(conv, "global")
    s0.compute_at(conv_global, conv, s0[conv].iters[3])

    # 2: simple cache_read with compute_at.
    kernel_global = s0.cache_read(kernel, "global", [conv_global])
    s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4])

    # State after steps 0-2 ("None" = extent not inferred yet):
    #   Placeholder: Data, Kernel_data
    #   for i0 (0,4)
    #     for i1 (0,512)
    #       for i2 (0,9)
    #         for i3 (0,9)
    #           pad_temp = ...
    #   for i0 (0,512)
    #     for i1 (0,512)
    #       for i2 (0,3)
    #         for i3 (0,3)
    #           Kernel_split = ...
    #   for i0 (0,512)
    #     for i1 (0,512)
    #       for i2 (0,3)
    #         for i3 (0,3)
    #           Kernel = ...
    #   for nn (0,4)
    #     for ff (0,512)
    #       for yy (0,7)
    #         for xx (0,7)
    #           for nn_c (None)
    #             for ff_c (None)
    #               for yy_c (None)
    #                 for xx_c (None)
    #                   for rc (None)
    #                     for ax0 (None)
    #                       for ax1 (None)
    #                         for ax2 (None)
    #                           for ax3 (None)
    #                             Kernel.global = ...
    #                     for ry (None)
    #                       for rx (None)
    #                         compute.global = ...
    #           compute = ...
    #   for ax0.0 (0,2)
    #     for ax1 (0,512)
    #       for ax0.1 (0,2)
    #         for ax2 (0,7)
    #           for ax3 (0,7)
    #             T_add = ...
    s1 = dag.infer_bound_from_state(s0)
    check_extents(s1, conv, [4, 512, 7, 7])
    check_extents(s1, kernel_global, [1, 1, 3, 3])
    check_extents(s1, conv_global, [1, 1, 1, 1, 512, 3, 3])

    # 3: two-level cache_read with compute_at, in preparation for GPU shared
    # and local memory.
    pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global])
    pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global])
    s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2])
    s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4])

    # 4: cache_read with multiple readers; this stage cannot be computed at
    # its consumer.
    s0.cache_read(data, "global", [pad_temp, add])

    # Changes to the state after steps 3-4:
    #   - a new first stage: for ax0 (0,4) / ax1 (0,512) / ax2 (0,7) / ax3 (0,7)
    #     computing Data.global = ...
    #   - inside yy_c: loops ax0..ax3 (None) computing pad_temp.global = ...,
    #     placed before xx_c
    #   - inside rc, after Kernel.global: loops ax0..ax3 (None) computing
    #     pad_temp.global.shared = ..., placed before ry / rx
    s1 = dag.infer_bound_from_state(s0)
    check_extents(s1, conv, [4, 512, 7, 7])
    check_extents(s1, kernel_global, [1, 1, 3, 3])
    check_extents(s1, conv_global, [1, 1, 1, 1, 512, 3, 3])
    check_extents(s1, pad_temp_global, [1, 512, 3, 3])
    check_extents(s1, pad_temp_shared, [1, 1, 3, 3])

    # 5: cache_write with multiple outputs.
    # TVM's cache_write actually has a bug with this case:
    #
    # After schedule.cache_write, TVM generates one new stage:
    #   From: kernel_data -> kernel_split -> kernel
    #   To:   kernel_data -> kernel_split_global -> kernel_split -> kernel
    #
    # But with topo sort analysis, we get:
    #   // kernel_data -> kernel_split_global -> kernel_split -> kernel
    #          \                                                /
    #           ----------------> kernel_split ---------------->
    #
    # TODO(jcf94): Seems there's a bug with the input/output tensor. Such a
    # multi-output case should be unusual, so DoCacheWrite carries a hack for
    # it. This should be fixed later.
    kernel_split_global = s0.cache_write(kernel_split, "global")

    # Changes to the state after step 5: after pad_temp comes a
    # Kernel_split.global stage (loops i0_c (0,512) / i1_c (0,512) / i2_c (0,3)
    # / i3_c (0,3)), followed by TWO Kernel_split stages (loops i0..i3 with
    # extents 512, 512, 3, 3). The duplicate is the bug described above; there
    # should be only one Kernel_split stage.
    assert len(s0[kernel_split].iters) == len(s0[kernel_split_global].iters)
    for it_orig, it_cached in zip(s0[kernel_split].iters, s0[kernel_split_global].iters):
        assert it_orig.range == it_cached.range
def test_gpu_feature():
    """Check GPU-related per-store features extracted from a recorded CUDA matmul state."""
    # Use records to build a complicated GPU program
    json_records = "\n".join((
        """{"i": [["[\\"matmul_auto_scheduler_test\\", 512, 512, 512]", "cuda"], [[], [["CHW", 2, "local"], ["SP", 2, 0, 512, [1, 16, 32, 1], 1], ["SP", 2, 5, 512, [4, 1, 1, 16], 1], ["SP", 2, 10, 512, [1, 2], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 3, 0, 1, 3], ["FSP", 3, 4, 2, 3], ["RE", 3, [0, 4, 1, 5, 2, 6, 3, 7]], ["FU", 2, [0, 1]], ["FU", 3, [0, 1]], ["FU", 2, [1, 2]], ["FU", 3, [1, 2]], ["FU", 2, [2, 3]], ["FU", 3, [2, 3]], ["CA", 2, 3, 2], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 3], ["FU", 2, [0, 1]], ["FFSP", 2, 0, [1, 2], 1, 1], ["AN", 2, 1, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 3], ["FU", 1, [0, 1]], ["FFSP", 1, 0, [1, 2], 1, 1], ["AN", 1, 1, 6], ["AN", 5, 0, 5], ["AN", 5, 1, 4], ["AN", 5, 2, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.00536798], 0, 2.49277, 1585564852], "v": "v0.1"}""",
    ))

    # load states
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(json_records)
        f.flush()  # make the record visible to RecordReader, which reopens by name
        inputs, _ = auto_scheduler.RecordReader(f.name).read_lines()

        inp = inputs[0]
        dag = auto_scheduler.ComputeDAG(inp.task.workload_key)
        task = auto_scheduler.SearchTask(
            dag,
            inp.task.workload_key,
            inp.task.target,
            None,
            auto_scheduler.HardwareParams(100000, 16, 64, 1 << 30, 1 << 30, 1 << 30, 1 << 30, 1 << 30),
        )

        state = dag.infer_bound_from_state(inp.state)
        fea = auto_scheduler.feature.get_per_store_features_from_states(
            [state], task)[0]
        names = auto_scheduler.feature.get_per_store_feature_names()

        # build feature dict: one {feature name -> value} mapping per row of `fea`
        fea_dicts = [dict(zip(names, row)) for row in fea]

        # lowered IR:
        #
        # Placeholder: A, B
        # blockIdx.x i.0@j.0@ (0,8)
        #   vthread i.1@j.1@ (0,4)
        #     threadIdx.x i.2@j.2@ (0,16)
        #       C.local auto_unroll: 1024
        #       for k.0 (0,256)
        #         for ax0@ax1@.0 (0,8)
        #           threadIdx.x ax0@ax1@.1 (0,16)
        #             B.shared = ...
        #         for ax0@ax1@.0 (0,64)
        #           threadIdx.x ax0@ax1@.1 (0,16)
        #             A.shared = ...
        #         for i_c.3 (0,32)
        #           for k.2 (0,2)
        #             for j_c.4 (0,16)
        #               C.local = ...
        #       for i.3 (0,32)
        #         for j.3 (0,16)
        #           C = ...

        # check gpu-related features
        assert fequal(fea_dicts[0]["blockIdx_x_len"], math.log2(8 + 1))
        assert fequal(fea_dicts[0]["vthread_len"], math.log2(4 + 1))
        assert fequal(fea_dicts[1]["threadIdx_x_len"], math.log2(16 + 1))
        assert fequal(fea_dicts[0]["threadIdx_y_len"], math.log2(1 + 1))
        assert fequal(fea_dicts[2]["blockIdx_z_len"], math.log2(1 + 1))
        assert fequal(fea_dicts[0]["is_gpu"], 1.0)
def _check_dag_pickle_roundtrip(dag):
    """Assert that pickling and unpickling ``dag`` preserves its initial state.

    Compares both the printed initial state and the number of stage ops of
    the original and the round-tripped ComputeDAG.
    """
    loaded_dag = pickle.loads(pickle.dumps(dag))
    assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())
    assert len(loaded_dag.get_init_state().stage_ops) == len(
        dag.get_init_state().stage_ops)


def test_stage_order():
    """Test if the stage order is preserved when recovering a DAG."""
    N = 512
    A, B, C, D, E = parallel_matmul_auto_scheduler_test(N)
    sch = te.create_schedule([D.op, E.op])
    (D_local, ) = sch.cache_write([D], "local")
    (E_local, ) = sch.cache_write([E], "local")
    sch.cache_read(A, "shared", [D_local])
    sch.cache_read(B, "shared", [D_local])
    sch.cache_read(A, "shared", [E_local])
    sch.cache_read(C, "shared", [E_local])

    dag = auto_scheduler.ComputeDAG(sch)
    stage_ops_1 = dag.get_init_state().stage_ops

    # 3 placeholder, 4 x.shared, 2 {D,E}.local, 2 {D,E} compute
    assert len(stage_ops_1) == 11

    # Cache read stage should follow the source stage
    for idx, op in enumerate(stage_ops_1):
        if op.name == "A":
            assert (stage_ops_1[idx + 1].name == "A.d.shared"
                    and stage_ops_1[idx + 2].name == "A.shared")
        elif op.name in ["B", "C"]:
            assert stage_ops_1[idx + 1].name == "%s.shared" % op.name

    # Serialize and deserialize the ComputeDAG constructed by a schedule.
    _check_dag_pickle_roundtrip(dag)

    # Apply the same schedule to Ansor state and it should have the same stage order
    dag = auto_scheduler.ComputeDAG([A, B, C, D, E])
    state = dag.get_init_state()

    D_local = state.cache_write(D, "local")
    E_local = state.cache_write(E, "local")
    state.cache_read(A, "shared", [D_local])
    state.cache_read(B, "shared", [D_local])
    state.cache_read(A, "shared", [E_local])
    state.cache_read(C, "shared", [E_local])

    stage_ops_2 = state.stage_ops
    assert len(stage_ops_1) == len(stage_ops_2)

    # Both construction paths (te schedule vs. Ansor state steps) must yield
    # the same stage order, stage by stage.
    for op1, op2 in zip(stage_ops_1, stage_ops_2):
        assert op1.name == op2.name

    # Serialize and deserialize the ComputeDAG constructed by a list of tensor ops.
    _check_dag_pickle_roundtrip(dag)

    # Serialize and deserialize the search task.
    task = auto_scheduler.SearchTask(
        dag,
        json.dumps(("test-key", )),
        tvm.target.Target("llvm"),
        hardware_params=auto_scheduler.HardwareParams(100000, 16, 64, 0, 0, 0, 0, 0),
    )

    task2 = pickle.loads(pickle.dumps(task))
    # Unpickling the task is expected to register its workload key.
    assert "test-key" in auto_scheduler.workload_registry.WORKLOAD_FUNC_REGISTRY
    assert str(task.dag.get_init_state()) == str(task2.dag.get_init_state())
    assert len(task.dag.get_init_state().stage_ops) == len(
        task2.dag.get_init_state().stage_ops)
    assert task.workload_key == task2.workload_key
    assert str(task.target) == str(task2.target)
    assert task.hardware_params.num_cores == task2.hardware_params.num_cores
    assert task.hardware_params.vector_unit_bytes == task2.hardware_params.vector_unit_bytes
    assert task.hardware_params.cache_line_bytes == task2.hardware_params.cache_line_bytes
def test_layout_rewrite_correctness():
    """Check that a layout-rewritten matmul computes the same result as the plain one.

    Tunes a small matmul for two trials, rebuilds the best state with and
    without layout rewriting, and feeds both builds the same random inputs.
    For the reference build, the rewritten weight is permuted and reshaped
    back. The device buffers after execution are then compared.
    """
    N = 128
    target = "llvm"
    workload = matmul_auto_scheduler_test
    workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    target = tvm.target.create(target)
    task = auto_scheduler.SearchTask(dag, workload_key, target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)
        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner='local',
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
        auto_scheduler.auto_schedule(task, search_policy, tuning_options)

        inp, _ = auto_scheduler.load_best(log_file, workload_key, target)
        s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state,
                                                     layout_rewrite=False)
        # Inputs are generated from the rewritten buffers' shapes; the
        # reference copies are adapted to the original layout below.
        np_args = [
            np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
            for x in bufs
        ]
        np_args_ref = [np.array(x) for x in np_args]

        weight = np_args_ref[1]
        # infer shape for the rewritten layout
        if len(weight.shape) >= 6:
            # For cpu tile structure SSRSRS: the last six dims are tiles where
            # dims base+2 and base+4 are reduction tiles and the rest are
            # spatial. Move the reduction tiles to the front, then flatten back
            # to a 2-D (red_dim, out_dim) weight for the reference build.
            base = len(weight.shape) - 6
            red_dim = weight.shape[2 + base] * weight.shape[4 + base]
            out_dim = weight.shape[3 + base] * weight.shape[5 + base]
            for i in range(base + 2):
                out_dim *= weight.shape[i]
            new_order = [
                2 + base,
                4 + base,
            ] + list(range(base + 2)) + [
                3 + base,
                5 + base,
            ]
            np_args_ref[1] = np_args_ref[1].transpose(new_order)
            np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))

        func = tvm.build(s, bufs, target=inp.task.target,
                         target_host=inp.task.target_host)
        func_ref = tvm.build(s_ref, bufs_ref, target='llvm')

        ctx = tvm.context(str(inp.task.target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        # Compare the device buffers, NOT np_args/np_args_ref: tvm.nd.array
        # copies its input, so the numpy arrays are never written by the
        # kernels and comparing them would always pass.
        np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy())
        np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(),
                                   rtol=1e-5)