def test_correctness_layout_rewrite_insert_transform_stage():
    """Check that the ``InsertTransformStage`` layout rewrite keeps results intact.

    The ``matmul_auto_scheduler_test`` workload is created with arguments
    ``(128, 128, 128)`` for the "llvm" target and tuned for two measurement
    trials. The best recorded state is then replayed twice, once with
    ``layout_rewrite=ComputeDAG.InsertTransformStage`` and once with the
    default, both schedules are built and run on identical random inputs, and
    the first three buffers of each run are compared.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        # Only the path is used; the records are written and read through it.
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        try:
            tuning_options = auto_scheduler.TuningOptions(
                num_measure_trials=2,
                runner=measure_ctx.runner,
                verbose=1,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            auto_scheduler.auto_schedule(task, search_policy, tuning_options)
            inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)

            # Same best state, replayed with and without the layout rewrite.
            s, bufs = dag.apply_steps_from_state(
                inp.state,
                layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage,
            )
            s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)

            np_args = [
                np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
                for x in bufs
            ]

            func = tvm.build(s, bufs, target=target)
            func_ref = tvm.build(s_ref, bufs_ref, target=target)

            ctx = tvm.context(str(target))
            ctx_ref = tvm.cpu()

            # Both runs start from the same host data.
            args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
            args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args]
            ctx.sync()

            func(*args)
            func_ref(*args_ref)
            ctx.sync()

            # Buffers 0, 1 and 2 must agree between the two builds.
            for idx in range(3):
                tvm.testing.assert_allclose(
                    args[idx].asnumpy(), args_ref[idx].asnumpy(), atol=1e-3, rtol=1e-3
                )
        finally:
            # Drop the measurement context on failure as well as on success, so
            # a failing assertion does not keep it alive.
            # NOTE(review): presumably this is what stops the local RPC
            # processes -- confirm against LocalRPCMeasureContext.
            del measure_ctx
def tune(self, n_trial, **kwargs):
    """Run the auto-scheduler on ``self.auto_task`` for ``n_trial`` measurements.

    Parameters
    ----------
    n_trial : int
        Forwarded as ``num_measure_trials`` to ``TuningOptions``.
    **kwargs
        Accepted but not used by this method.

    The instance is published through the module-level ``GLOBAL_TUNER``
    before the search starts.
    """
    global GLOBAL_TUNER
    GLOBAL_TUNER = self

    # No measure callbacks are registered for this search.
    options = auto_scheduler.TuningOptions(
        num_measure_trials=n_trial,
        runner=self.measure_ctx.runner,
        measure_callbacks=[],
    )
    auto_scheduler.auto_schedule(self.auto_task, tuning_options=options)
def tune(self, n_trial, **kwargs):
    """Run the auto-scheduler on ``self.auto_task``; exit the process on failure.

    Parameters
    ----------
    n_trial : int
        Forwarded as ``num_measure_trials`` to ``TuningOptions``.
    **kwargs
        Accepted but not used by this method.

    The instance is published through the module-level ``GLOBAL_TUNER``
    before the search starts. ``self.task.n_parallel`` is used as the number
    of measurements per round.

    If the search raises an ``Exception``, its traceback is printed and the
    process exits with status 1. ``KeyboardInterrupt`` and ``SystemExit`` are
    deliberately left to propagate instead of being rewritten into that exit.
    """
    # Local imports keep this method self-contained; ``sys.exit`` is used
    # because the bare ``exit`` helper is only injected by the ``site`` module.
    import sys
    import traceback

    global GLOBAL_TUNER
    GLOBAL_TUNER = self

    try:
        auto_scheduler.auto_schedule(
            self.auto_task,
            tuning_options=auto_scheduler.TuningOptions(
                num_measure_trials=n_trial,
                num_measures_per_round=self.task.n_parallel,
                runner=self.measure_ctx.runner,
                measure_callbacks=[],
            ),
        )
    except Exception:
        traceback.print_exc()
        sys.exit(1)
def search_common(workload=matmul_auto_scheduler_test, target="llvm",
                  search_policy='empty', seed=random.randint(1, 1 << 30),
                  runner='local', cost_model=auto_scheduler.RandomModel(),
                  num_measure_trials=2, init_search_callbacks=None):
    """Run a schedule search for ``workload`` and verify the result numerically.

    Parameters
    ----------
    workload
        Workload passed to ``auto_scheduler.make_workload_key`` together with
        the arguments ``(128, 128, 128)``.
    target : str
        Target string handed to ``tvm.target.create``.
    search_policy : str
        ``'empty'`` selects ``EmptyPolicy`` and ``'sketch'`` selects
        ``SketchPolicy``. Any other string raises ``ValueError``; a non-string
        value is passed to ``auto_schedule`` unchanged.
    seed : int
        Seed for the ``random`` module. The default is drawn once, when the
        function is defined, so every call that omits it shares one value.
    runner
        Forwarded to ``TuningOptions``.
    cost_model
        Cost model given to ``SketchPolicy``. The default instance is created
        once at definition time and shared by all calls.
    num_measure_trials : int
        Forwarded to ``TuningOptions``.
    init_search_callbacks : list or None
        Extra callbacks for ``SketchPolicy``. The list is copied, so the
        caller's object is not modified.

    Raises
    ------
    ValueError
        If ``search_policy`` is a string other than ``'empty'``/``'sketch'``.
    Exception
        If lowering, building, running or checking the schedule fails. The
        message carries the seed and the original error is chained as cause.
    """
    print("Test %s schedule search with the default search policy" % (target))

    random.seed(seed)
    N = 128
    workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    target = tvm.target.create(target)
    task = auto_scheduler.SearchTask(dag, workload_key, target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        # Work on a copy: appending to the caller's list would leak the
        # preload callback into later calls that reuse the same list.
        callbacks = list(init_search_callbacks or [])
        callbacks.append(auto_scheduler.PreloadMeasuredStates(log_file))

        if search_policy == 'empty':
            search_policy = auto_scheduler.EmptyPolicy(task)
        elif search_policy == 'sketch':
            # Pass the cost model on; it was previously accepted and ignored.
            search_policy = auto_scheduler.SketchPolicy(
                task, cost_model, init_search_callbacks=callbacks)
        elif isinstance(search_policy, str):
            raise ValueError("Invalid policy: " + search_policy)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=num_measure_trials,
            runner=runner,
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
        sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, res = auto_scheduler.load_best(log_file, workload_key, target)

        print("==== Python Code ====")
        print(dag.print_python_code_from_state(inp.state))

        try:
            print("==== Lowered Stmt ====")
            print(tvm.lower(sch, args, simple_mode=True))
            mod = tvm.build(sch, args, target)

            ctx = tvm.context(str(target), 0)
            dtype = dag.tensors[0].dtype
            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
            mod(a, b, c)
            # The built module must reproduce a plain NumPy matrix product.
            tvm.testing.assert_allclose(
                c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
            print("==== Verification passed ====")
        except Exception as err:
            # Report the seed for reproduction while keeping the real cause.
            raise Exception("Error encountered with seed: %d" % (seed)) from err
    print()
def resume_search(task, log_file):
    """Run five more measurement trials for ``task``, starting from ``log_file``.

    An ``XGBModel`` is updated from the log file, and the same file is given to
    ``PreloadMeasuredStates`` as an initial search callback and to
    ``RecordToFile`` as a measure callback. Nothing is returned.
    """
    model = auto_scheduler.XGBModel()
    model.update_from_file(log_file)

    preload = auto_scheduler.PreloadMeasuredStates(log_file)
    policy = auto_scheduler.SketchPolicy(task, model, init_search_callbacks=[preload])

    recorder = auto_scheduler.RecordToFile(log_file)
    options = auto_scheduler.TuningOptions(
        num_measure_trials=5,
        measure_callbacks=[recorder],
    )

    # The result is unpacked as a pair, as before, but not used further.
    _sch, _args = auto_scheduler.auto_schedule(task, policy, tuning_options=options)
# Path handed to RecordToFile below as the destination for measurement records.
log_file = "conv2d.json"
# The measurement context supplies the runner used by the tuning options; it is
# deleted again right after the search.
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # change this to 1000 to achieve the best performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)

######################################################################
# Run the search
# ^^^^^^^^^^^^^^
# Now we get all inputs ready. Pretty simple, isn't it?
# We can kick off the search and let the auto-scheduler do its magic.
# After some measurement trials, it will return the best schedule it found.

sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)

# Kill the process for measurement
del measure_ctx

######################################################################
# We can lower the schedule to see the IR after auto-scheduling.
# The auto-scheduler correctly performs optimizations including multi-level tiling,
# cooperative fetching, unrolling and operator fusion.

print(tvm.lower(sch, args, simple_mode=True))

######################################################################
# Check correctness and evaluate performance
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# We build the binary and check its correctness and performance.
# One tag, built from the problem sizes and the device, names the log file,
# the built function and both saved artifacts.
func_name = "dense_fwd_n%s_ci%s_co%s_%s" % (N, CI, CO, device)
log_file = "%s.json" % func_name
task = tvm.auto_scheduler.create_task(dense_fwd, (N, CI, CO, "float32"), target)

### search
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=num_search_trails,
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)
sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)
# The measurement context is not needed once the search has returned.
del measure_ctx

### load history
# Alternative to searching: replay the best record from an existing log.
# inp, res = auto_scheduler.load_best(log_file, task.workload_key)
# sch, args = task.compute_dag.apply_steps_from_state(inp.state)

# build func
ctx = tvm.gpu()
func = tvm.build(sch, args, target, name=func_name)

# save result: the built module itself and its first imported module
obj_fname = "%s.o" % func_name
ptx_fname = "%s.ptx" % func_name
func.save(obj_fname)
func.imported_modules[0].save(ptx_fname)
def test_layout_rewrite_correctness():
    """Check that a schedule replayed with ``layout_rewrite=True`` computes the
    same values as the one replayed with ``layout_rewrite=False``.

    The ``matmul_auto_scheduler_test`` workload is created with arguments
    ``(128, 128, 128)`` for the "llvm" target and tuned for two measurement
    trials. The best recorded state is replayed both ways. Random inputs are
    generated in the shapes of the rewritten schedule's buffers, and the second
    buffer is transposed and reshaped back to two dimensions for the reference
    run. After both kernels have executed, buffers 0 and 2 are compared.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        # Only the path is used; the records are written and read through it.
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner="local",
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
        s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state, layout_rewrite=False)

        np_args = [
            np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
            for x in bufs
        ]
        # Independent copies, so the reference weight can be reshaped freely.
        np_args_ref = [np.array(x) for x in np_args]

        weight = np_args_ref[1]
        # infer shape for the rewritten layout
        if len(weight.shape) >= 6:
            # For cpu tile structure SSRSRS
            base = len(weight.shape) - 6
            red_dim = weight.shape[2 + base] * weight.shape[4 + base]
            out_dim = weight.shape[3 + base] * weight.shape[5 + base]
            for i in range(base + 2):
                out_dim *= weight.shape[i]
            # Move the two reduction axes to the front and keep the remaining
            # axes in order, then collapse to (red_dim, out_dim).
            new_order = (
                [2 + base, 4 + base]
                + list(range(base + 2))
                + [3 + base, 5 + base]
            )
            np_args_ref[1] = np_args_ref[1].transpose(new_order)
            np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))

        func = tvm.build(s, bufs, target=target)
        func_ref = tvm.build(s_ref, bufs_ref, target=target)

        ctx = tvm.context(str(target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        # Compare the device buffers after execution. The host arrays
        # ``np_args``/``np_args_ref`` are copies of the same random data and are
        # never written by the kernels, so comparing them would always pass.
        # Buffer 1 is skipped because its layout differs between the two runs.
        # Tolerances allow for the two schedules rounding differently.
        np.testing.assert_allclose(
            args[0].asnumpy(), args_ref[0].asnumpy(), atol=1e-3, rtol=1e-3
        )
        np.testing.assert_allclose(
            args[2].asnumpy(), args_ref[2].asnumpy(), atol=1e-3, rtol=1e-3
        )
def search_common(
    workload=matmul_auto_scheduler_test,
    target="llvm",
    search_policy="sketch",
    seed=0,
    runner="local",
    num_measure_trials=100,
    cost_model=auto_scheduler.RandomModel(),
    init_search_callbacks=None,
):
    """Run a schedule search for ``workload`` and verify the result numerically.

    Parameters
    ----------
    workload
        Workload passed to ``auto_scheduler.create_task`` together with the
        arguments ``(128, 128, 128)``.
    target : str
        Target string handed to ``tvm.target.Target``.
    search_policy : str
        ``"empty"`` selects ``EmptyPolicy`` and ``"sketch"`` selects
        ``SketchPolicy``; anything else raises ``ValueError``.
    seed : int
        Seed for the ``random`` module.
    runner
        Forwarded to ``TuningOptions``.
    num_measure_trials : int
        Forwarded to ``TuningOptions``.
    cost_model
        Passed to ``SketchPolicy`` as ``program_cost_model``. The default
        instance is created once at definition time and shared by all calls.
    init_search_callbacks : list or None
        Extra callbacks for ``SketchPolicy``. The list is copied, so the
        caller's object is not modified.

    Raises
    ------
    ValueError
        If ``search_policy`` is neither ``"empty"`` nor ``"sketch"``.
    Exception
        If lowering, building, running or checking the schedule fails. The
        message carries the seed and the original error is chained as cause.
    """
    print("Test search policy '%s' for '%s'" % (search_policy, target))

    random.seed(seed)
    N = 128
    target = tvm.target.Target(target)
    task = auto_scheduler.create_task(workload, (N, N, N), target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        # Work on a copy: appending to the caller's list would leak the
        # preload callback into later calls that reuse the same list.
        callbacks = list(init_search_callbacks or [])
        callbacks.append(auto_scheduler.PreloadMeasuredStates(log_file))

        if search_policy == "empty":
            search_policy = auto_scheduler.EmptyPolicy(task)
        elif search_policy == "sketch":
            search_policy = auto_scheduler.SketchPolicy(
                task, program_cost_model=cost_model, init_search_callbacks=callbacks
            )
        else:
            # %-formatting keeps this a ValueError for non-string values too.
            raise ValueError("Invalid policy: %s" % (search_policy,))

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=num_measure_trials,
            num_measures_per_round=2,
            early_stopping=1,
            runner=runner,
            verbose=2,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, res = auto_scheduler.load_best(log_file, task.workload_key, target)

        print("==== Python Code ====")
        print(task.compute_dag.print_python_code_from_state(inp.state))

        try:
            print("==== Lowered Stmt ====")
            print(tvm.lower(sch, args, simple_mode=True))
            mod = tvm.build(sch, args, target)

            ctx = tvm.context(str(target), 0)
            dtype = task.compute_dag.tensors[0].dtype
            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
            mod(a, b, c)
            # The built module must reproduce a plain NumPy matrix product.
            tvm.testing.assert_allclose(
                c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5
            )
            print("==== Verification passed ====")
        except Exception as err:
            # Report the seed for reproduction while keeping the real cause.
            raise Exception("Error encountered with seed: %d" % (seed)) from err
    print()