def test_measure_special_inputs_map_by_name_rpc_runner():
    @auto_scheduler.register_workload
    def foo():
        X = te.placeholder(shape=[10], dtype="int32")
        Index = te.placeholder(shape=[1], dtype="int32", name="Index")
        Y = te.compute((1,), lambda i: X[Index[i]])
        return [X, Index, Y]

    # This workload cannot use random input for the `Index` input
    task = auto_scheduler.SearchTask(
        func=foo,
        target="llvm",
        task_inputs={
            "Index": tvm.nd.array(np.array([5], dtype="int32")),
        },
    )

    for enable_cpu_cache_flush in [True, False]:
        minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
        local_builder = auto_scheduler.LocalBuilder()
        measure_ctx = auto_scheduler.LocalRPCMeasureContext(
            timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
        )
        rpc_runner = measure_ctx.runner

        bress = local_builder.build([minp])
        assert bress[0].error_no == 0
        mress = rpc_runner.run([minp], bress)
        assert mress[0].error_no == 0
def test_tuning_cuda():
    auto_scheduler.enable_relay_integration()

    # Extract tasks
    mod, params = get_network("mlp")
    target = tvm.target.Target("cuda")
    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
    objective = lambda costs: sum(c * w for c, w in zip(costs, task_weights))

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        # Tuning
        measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=100)
        tuner = auto_scheduler.TaskScheduler(tasks, objective)
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            num_measures_per_round=1,
            runner=measure_ctx.runner,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        tuner.tune(tune_option, search_policy="sketch.random")
        del measure_ctx

        # Compile with the history best
        with auto_scheduler.ApplyHistoryBest(log_file):
            with tvm.transform.PassContext(opt_level=3):
                lib = relay.build(mod, target=target, params=params)

    # Todo(merrymercy): compile without any history to test the fallback mechanism

    auto_scheduler.enable_relay_integration(False)
def auto_scheduler_tune(network, target, input_name, log_file):
    if os.path.exists(log_file):
        os.remove(log_file)

    mod, net_params, input_shape, output_shape = get_network(network)
    if network not in ["bert"]:
        # convert to NHWC layout
        desired_layouts = {'nn.conv2d': ['NHWC', 'default']}
        seq = tvm.transform.Sequential(
            [
                relay.transform.RemoveUnusedFunctions(),
                relay.transform.ConvertLayout(desired_layouts),
            ]
        )
        with tvm.transform.PassContext(opt_level=3):
            mod = seq(mod)

    if "cpu" in target.keys:
        tuning_opt = auto_scheduler.TuningOptions(
            num_measure_trials=20000,  # 20000 trials are typically enough to reach the best performance
            runner=auto_scheduler.LocalRunner(repeat=10, enable_cpu_cache_flush=True),
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
    else:
        measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=300, timeout=10)
        tuning_opt = auto_scheduler.TuningOptions(
            num_measure_trials=20000,  # 20000 trials are typically enough to reach the best performance
            runner=measure_ctx.runner,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )

    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], net_params, target)
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tuner.tune(tuning_opt)
def tune_network(network, target): # Extract tasks mod, params = get_network(network) target = tvm.target.Target(target) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) with tempfile.NamedTemporaryFile() as fp: log_file = fp.name # Tuning measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60) tuner = auto_scheduler.TaskScheduler(tasks, task_weights) tune_option = auto_scheduler.TuningOptions( num_measure_trials=100, num_measures_per_round=2, early_stopping=1, runner=measure_ctx.runner, builder=auto_scheduler.LocalBuilder(timeout=60), measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) tuner.tune(tune_option, search_policy="sketch.random") del measure_ctx # Compile with the history best with auto_scheduler.ApplyHistoryBest(log_file): with tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_auto_scheduler": True}): lib = relay.build(mod, target=target, params=params)
def local_auto_scheduler(self, repeat=1, min_repeat_ms=300, timeout=10, num_measure_trials=200):
    # extract tasks
    tasks, task_weights = auto_scheduler.extract_tasks(self.mod["main"], self.params, self.target)
    for idx, task in enumerate(tasks):
        logger.debug("========== Task %d (workload key: %s) ==========" % (idx, task.workload_key))
        logger.debug(task.compute_dag)

    # generate tuner
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)

    logging.info("Begin tuning...")
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
        repeat=repeat, min_repeat_ms=min_repeat_ms, timeout=timeout
    )
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=num_measure_trials,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(self.log_file)],
    )
    tuner.tune(tune_option)

    # update self.lib
    with auto_scheduler.ApplyHistoryBest(self.log_file):
        with tvm.transform.PassContext(
            opt_level=3, config={"relay.backend.use_auto_scheduler": True}
        ):
            self._lib = relay.build(self.mod, target=self.target, params=self.params)
    logger.info(f"load optimized library from {self.log_file}")
def test_correctness_layout_rewrite_insert_transform_stage():
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name
        search_policy = auto_scheduler.SketchPolicy(task)

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner=measure_ctx.runner,
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
        s, bufs = dag.apply_steps_from_state(
            inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage
        )
        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
        np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]

        func = tvm.build(s, bufs, target=target)
        func_ref = tvm.build(s_ref, bufs_ref, target=target)

        ctx = tvm.context(str(target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), atol=1e-3, rtol=1e-3)
        tvm.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy(), atol=1e-3, rtol=1e-3)
        tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), atol=1e-3, rtol=1e-3)
        del measure_ctx
def tune_network(network, target): # Extract tasks mod, params = get_network(network) target = tvm.target.Target(target) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) with tempfile.NamedTemporaryFile() as fp: log_file = fp.name # Tuning measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60) tuner = auto_scheduler.TaskScheduler(tasks, task_weights) tune_option = auto_scheduler.TuningOptions( num_measure_trials=100, num_measures_per_round=2, early_stopping=1, runner=measure_ctx.runner, builder=auto_scheduler.LocalBuilder(timeout=60), measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) tuner.tune(tune_option, search_policy="sketch.random") del measure_ctx # Compile with the history best with auto_scheduler.ApplyHistoryBest(log_file): with tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_auto_scheduler": True}): lib = relay.build(mod, target=target, params=params) # Compile without auto-scheduler and any other optimization for correctness check with tvm.transform.PassContext(opt_level=0): lib2 = relay.build(mod, target=target, params=params) # Check the correctness def get_output(data, lib): ctx = tvm.gpu() module = graph_runtime.GraphModule(lib["default"](ctx)) module.set_input("data", data) module.run() return module.get_output(0).asnumpy() np.random.seed(0) if network == "mlp": data = np.random.uniform(size=(1, 32)) elif network == "winograd-test": data = np.random.uniform(size=(1, 23, 40, 32)) else: raise ValueError("Unknown network: " + network) actual_output = get_output(data, lib) expected_output = get_output(data, lib2) tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4, atol=1e-4)
def test_sketch_search_policy_cuda_rpc_runner():
    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
    # wrap the search in a new thread to avoid the conflict
    # between python's multiprocessing and tvm's thread pool
    t = PropagatingThread(
        target=search_common,
        kwargs={
            'seed': 944563397,
            'search_policy': 'sketch',
            'target': 'cuda',
            'runner': measure_ctx.runner,
        },
    )
    t.start()
    t.join()
def test_sketch_search_policy_cuda_xgbmodel_rpc_runner(): if not tvm.runtime.enabled("cuda"): return measure_ctx = auto_scheduler.LocalRPCMeasureContext() # wrap the search in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool t = PropagatingThread(target=search_common, kwargs={'seed': 944563397, 'search_policy': 'sketch', 'target': 'cuda', 'runner': measure_ctx.runner, 'cost_model': auto_scheduler.XGBModel()}) t.start() t.join()
def __init__(self, task, **kwargs):
    self.task = task
    self.measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)

    @auto_scheduler.register_workload
    def auto_template():
        _, arg_bufs = task.func()
        return arg_bufs

    self.auto_task = auto_scheduler.create_task(auto_template, (), task.target)
def run_tuning():
    print("Begin tuning...")
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=300, timeout=10)
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=200,  # change this to 20000 to achieve the best performance
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    tuner.tune(tune_option)
def test_sketch_search_policy_cuda_rpc_runner(): measure_ctx = auto_scheduler.LocalRPCMeasureContext() # wrap the search in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool t = PropagatingThread( target=search_common, kwargs={ "target": "cuda", "runner": measure_ctx.runner, }, ) t.start() t.join()
def test_sketch_search_policy_zero_rank():
    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
    for target in ["llvm", "cuda"]:
        task = auto_scheduler.SearchTask(
            func=zero_rank_compute_auto_scheduler_test, args=(10,), target=target
        )
        search_common(task, runner=measure_ctx.runner)

        task = auto_scheduler.SearchTask(
            func=zero_rank_reduce_auto_scheduler_test, args=(10,), target=target
        )
        search_common(task, runner=measure_ctx.runner)
def test_task_scheduler_round_robin():
    tasks = []
    for n in [2, 4, 8]:
        tasks.append(auto_scheduler.create_task(matmul_auto_scheduler_test, (n, n, n), "llvm"))

    def objective_func(costs):
        return sum(costs)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name
        num_trials_per_task = 2

        # Tune all tasks
        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=num_trials_per_task * len(tasks),
            runner=measure_ctx.runner,
            num_measures_per_round=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        task_scheduler = auto_scheduler.TaskScheduler(tasks, objective_func, strategy="round-robin")
        task_scheduler.tune(tune_option, search_policy="sketch.random")

        # Check the result of round robin
        counters = {}
        for task in tasks:
            counters[task.workload_key] = 0

        for inp, res in auto_scheduler.load_records(log_file):
            counters[inp.task.workload_key] += 1

        for task in tasks:
            assert counters[task.workload_key] == num_trials_per_task

        # test continuous tuning (restoring the status)
        task_scheduler = auto_scheduler.TaskScheduler(
            tasks, objective_func, strategy="round-robin", load_log_file=log_file
        )
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=len(tasks),
            num_measures_per_round=1,
        )
        task_scheduler.tune(tune_option, search_policy="sketch.random")
        del measure_ctx
def test_task_scheduler_gradient():
    tasks = []
    for n in [2, 4]:
        tasks.append(
            auto_scheduler.SearchTask(func=matmul_auto_scheduler_test, args=(n, n, n), target="llvm")
        )

    def objective_func(costs):
        return costs[0]

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        n_trials = 5

        # Tune all tasks
        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=n_trials,
            runner=measure_ctx.runner,
            num_measures_per_round=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        task_scheduler = auto_scheduler.TaskScheduler(
            tasks, objective_func=objective_func, callbacks=[]
        )

        # Forcibly rewrite the initial values.
        # This makes the test more stable on slow CI machines.
        task_scheduler.best_costs = np.array([1e2, 1e-8])

        task_scheduler.tune(tune_option, search_policy="sketch.random")

        # Check the allocation results
        counters = {}
        for task in tasks:
            counters[task.workload_key] = 0

        for inp, _ in auto_scheduler.load_records(log_file):
            counters[inp.task.workload_key] += 1

        assert counters[tasks[0].workload_key] == n_trials - 1
        assert counters[tasks[1].workload_key] == 1
        del measure_ctx
def resume_search(task, log_file):
    print("Resume search:")

    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )

    measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    task.tune(tune_option, search_policy=search_policy)

    # Kill the measurement process
    del measure_ctx
def test_measure_local_builder_rpc_runner(): if not tvm.runtime.enabled("llvm"): return dag, s0 = get_tiled_matmul() tgt = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", tgt) minp = auto_scheduler.MeasureInput(task, s0) local_builder = auto_scheduler.LocalBuilder() measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60) rpc_runner = measure_ctx.runner bress = local_builder.build([minp]) assert bress[0].error_no == 0 mress = rpc_runner.run([minp], bress) assert mress[0].error_no == 0
def test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False):
    if not tvm.testing.device_enabled("llvm"):
        return
    dag, s0 = get_tiled_matmul()
    tgt = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(dag, "test", tgt)

    minp = auto_scheduler.MeasureInput(task, s0)
    local_builder = auto_scheduler.LocalBuilder()
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
    )
    rpc_runner = measure_ctx.runner

    bress = local_builder.build([minp])
    assert bress[0].error_no == 0
    mress = rpc_runner.run([minp], bress)
    assert mress[0].error_no == 0
def test_sketch_search_policy_cuda_xgbmodel_rpc_runner(): if not tvm.runtime.enabled("cuda"): return measure_ctx = auto_scheduler.LocalRPCMeasureContext() # wrap the search in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool t = PropagatingThread( target=search_common, kwargs={ "seed": 944563397, "search_policy": "sketch", "target": "cuda", "runner": measure_ctx.runner, "cost_model": auto_scheduler.XGBModel(), }, ) t.start() t.join()
def test_measure_local_builder_rpc_runner(): if not tvm.testing.device_enabled("llvm"): return task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm") for enable_cpu_cache_flush in [True, False]: minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state) local_builder = auto_scheduler.LocalBuilder() measure_ctx = auto_scheduler.LocalRPCMeasureContext( timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush ) rpc_runner = measure_ctx.runner bress = local_builder.build([minp]) assert bress[0].error_no == 0 mress = rpc_runner.run([minp], bress) assert mress[0].error_no == 0 del measure_ctx
def auto_scheduler_tune(network, batch_size, dtype, target, log_file):
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    # if os.path.exists(log_file):
    #     os.remove(log_file)

    layout = "NHWC"
    mod, params, input_name, input_shape, output_shape = get_network(
        network, batch_size, dtype, layout
    )

    n_trials = network_to_n_trials[(network, batch_size, dtype, str(target.kind))]

    if "cpu" in target.keys:
        tuning_opt = auto_scheduler.TuningOptions(
            num_measure_trials=n_trials,
            runner=auto_scheduler.LocalRunner(repeat=10, enable_cpu_cache_flush=True),
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
    else:
        min_repeat_ms = 450 if network in ["bert"] else 300
        measure_ctx = auto_scheduler.LocalRPCMeasureContext(
            repeat=1, min_repeat_ms=min_repeat_ms, timeout=10
        )
        tuning_opt = auto_scheduler.TuningOptions(
            num_measure_trials=n_trials,
            runner=measure_ctx.runner,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )

    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
    print(log_file)
    update_file(log_file, tasks)
    # NOTE: this early return skips the task printing and tuning below
    return

    for idx, task in enumerate(tasks):
        print("========== Task %d (workload key: %s) ==========" % (idx, task.workload_key))
        print(task.compute_dag)

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tuner.tune(tuning_opt)
def tune_network(network, target):
    auto_scheduler.enable_relay_integration()

    # Extract tasks
    mod, params = get_network(network)
    target = tvm.target.Target(target)
    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        # Tuning
        measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60)
        tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=100,
            num_measures_per_round=2,
            early_stopping=1,
            runner=measure_ctx.runner,
            builder=auto_scheduler.LocalBuilder(timeout=60),
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        tuner.tune(tune_option, search_policy="sketch.random")
        del measure_ctx

        # Compile with the history best
        with auto_scheduler.ApplyHistoryBest(log_file):
            with tvm.transform.PassContext(opt_level=3):
                lib = relay.build(mod, target=target, params=params)

    # Todo(merrymercy): when the cpu backend is upstreamed, do the following things:
    # 1. compile without history to test the fallback mechanism
    # 2. check the correctness of layout rewrite / winograd pre-transform

    auto_scheduler.enable_relay_integration(False)
def __init__(self, task, **kwargs):
    self.task = task
    self.measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
    self.auto_task = create_auto_task(task.target)
def __init__(self, task, **kwargs):
    self.task = task
    self.measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
    self.auto_task = create_auto_task(task.target)
    assert backend in ('c-cuda', 'c-rocm'), "Ansor in Antares is enabled for CUDA/ROCm only."
def tune_network(network, target): # Extract tasks mod, params = get_network(network) target = tvm.target.Target(target) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) with tempfile.NamedTemporaryFile() as fp: log_file = fp.name # Tuning measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60, device=0) tuner = auto_scheduler.TaskScheduler(tasks, task_weights, callbacks=[]) tune_option = auto_scheduler.TuningOptions( num_measure_trials=100, num_measures_per_round=2, early_stopping=1, runner=measure_ctx.runner, builder=auto_scheduler.LocalBuilder(timeout=60), measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) tuner.tune(tune_option, search_policy="sketch.random") del measure_ctx # Compile with the history best with auto_scheduler.ApplyHistoryBest(log_file): with tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_auto_scheduler": True} ): lib = relay.build(mod, target=target, params=params) # Also test that multiple log files can be loaded. with auto_scheduler.ApplyHistoryBest([log_file, log_file]) as best: assert isinstance( best, auto_scheduler.dispatcher.ApplyHistoryBest ), "Unable to load multiple log files jointly." # Confirm iterables can be directly loaded. loaded_recs = auto_scheduler.dispatcher.load_records(log_file) with auto_scheduler.ApplyHistoryBest(iter(loaded_recs)) as best: assert isinstance( best, auto_scheduler.dispatcher.ApplyHistoryBest ), "Unable to ingest logs from an interator." # Sample a schedule when missing with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2): with tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_auto_scheduler": True} ): lib2 = relay.build(mod, target=target, params=params) # Compile without auto-scheduler and any other optimization for correctness check with tvm.transform.PassContext(opt_level=0): ref_lib = relay.build(mod, target=target, params=params) # Check the correctness def get_output(data, lib): dev = tvm.cuda() module = graph_executor.GraphModule(lib["default"](dev)) module.set_input("data", data) module.run() return module.get_output(0).numpy() np.random.seed(0) if network == "mlp": data = np.random.uniform(size=(1, 32)) elif network == "winograd-test": data = np.random.uniform(size=(1, 23, 40, 32)) else: raise ValueError("Unknown network: " + network) actual_output1 = get_output(data, lib) actual_output2 = get_output(data, lib2) expected_output = get_output(data, ref_lib) tvm.testing.assert_allclose(actual_output1, expected_output, rtol=1e-4, atol=1e-4) tvm.testing.assert_allclose(actual_output2, expected_output, rtol=1e-4, atol=1e-4)
def test_sketch_search_policy_cuda_rpc_runner(): measure_ctx = auto_scheduler.LocalRPCMeasureContext() search_common(target="cuda", runner=measure_ctx.runner)
def test_sketch_search_policy_cuda_xgbmodel_rpc_runner(): measure_ctx = auto_scheduler.LocalRPCMeasureContext() search_common(target="cuda", runner=measure_ctx.runner, cost_model=auto_scheduler.XGBModel())
def test_correctness_layout_rewrite_rewrite_for_preTransformed():
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name
        search_policy = auto_scheduler.SketchPolicy(task)

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner=measure_ctx.runner,
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
        s, bufs = dag.apply_steps_from_state(
            inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.RewriteForPreTransformed
        )
        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
        np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
        np_args_ref = [np.array(x) for x in np_args]

        weight = np_args_ref[1]
        # infer shape for the rewritten layout
        if len(weight.shape) >= 6:
            # For cpu tile structure SSRSRS
            base = len(weight.shape) - 6
            red_dim = weight.shape[2 + base] * weight.shape[4 + base]
            out_dim = weight.shape[3 + base] * weight.shape[5 + base]

            for i in range(base + 2):
                out_dim *= weight.shape[i]

            new_order = (
                [2 + base, 4 + base]
                + list(range(base + 2))
                + [3 + base, 5 + base]
            )
            np_args_ref[1] = np_args_ref[1].transpose(new_order)
            np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))

        func = tvm.build(s, bufs, target=target)
        func_ref = tvm.build(s_ref, bufs_ref, target=target)

        ctx = tvm.context(str(target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), atol=1e-3, rtol=1e-3)
        tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), atol=1e-3, rtol=1e-3)
        del measure_ctx
# provide isolation. It can protect the master process from GPU crashes
# during measurement and avoid other runtime conflicts.
# * :code:`min_repeat_ms` defines the minimum duration of one "repeat" in every measurement.
#   This can warmup the GPU, which is necessary to get accurate measurement results.
#   Typically, we recommend a value > 300 ms.
# * :code:`num_measure_trials` is the number of measurement trials we can use during the search.
#   We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a
#   good value for the search to converge. You can do more trials according to your time budget.
# * In addition, we use :code:`RecordToFile` to dump measurement records into a file `conv2d.json`.
#   The measurement records can be used to query the history best, resume the search,
#   and do more analyses later.
# * see :any:`auto_scheduler.TuningOptions`,
#   :any:`auto_scheduler.LocalRPCMeasureContext` for more parameters.

log_file = "conv2d.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # change this to 1000 to achieve the best performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)

######################################################################
# Run the search
# ^^^^^^^^^^^^^^
# Now we get all inputs ready. Pretty simple, isn't it?
# We can kick off the search and let the auto-scheduler do its magic.
# After some measurement trials, it will return the best schedule it found.

sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)
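######################################################################
# A minimal follow-up sketch (not part of the original excerpt): once the search
# returns, the schedule can be lowered and built like any other TE schedule, and the
# RPC measurement process should be terminated. Here `target` is assumed to be the
# same target object that was used to create `task` earlier in the tutorial.

print(tvm.lower(sch, args, simple_mode=True))  # inspect the best schedule found
func = tvm.build(sch, args, target)            # compile it into a runnable module
del measure_ctx                                # kill the RPC measurement process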
                             filter_size)
    input_info_list.append((data_shape, input_info_list))
    N, H, W, CO, CI, KH, KW, strides, padding = (
        batch_size,
        in_height,
        in_width,
        out_channel,
        in_channel,
        filter_size,
        filter_size,
        (1, 1),
        (1, 1),
    )
    task = auto_scheduler.SearchTask(
        func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target
    )
    create_tasks.append(task)

for task in create_tasks:
    print("---------------")
    print(task.compute_dag)

measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=300, timeout=10)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=1000,  # 1000 trials are usually enough to achieve good performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

print("we have {} tasks to tune".format(len(create_tasks)))
index = 0
for task in create_tasks[:5]:
    index += 1
    print("current tuning task {} .................".format(index))
    task.tune(tune_option)