def test_meta_schedule_integration_extract_from_resnet():
    mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params=params)
    expected_task_names = [
        "fused_" + s
        for s in [
            "nn_max_pool2d",
            "nn_adaptive_avg_pool2d",
            "nn_dense_add",
            "nn_conv2d_add",
            "nn_conv2d_add_1",
            "nn_conv2d_add_2",
            "nn_conv2d_add_add_nn_relu",
            "nn_conv2d_add_add_nn_relu_1",
            "nn_conv2d_add_nn_relu",
            "nn_conv2d_add_nn_relu_1",
            "nn_conv2d_add_nn_relu_2",
            "nn_conv2d_add_nn_relu_3",
            "nn_conv2d_add_nn_relu_4",
            "nn_conv2d_add_nn_relu_5",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1",
            # The two tasks below are purely spatial and are ruled out by AutoScheduler
            "layout_transform",
            "layout_transform_reshape_squeeze",
        ]
    ]
    assert len(extracted_tasks) == len(expected_task_names)
    for t in extracted_tasks:
        assert t.task_name in expected_task_names, t.task_name
def extract_task_qbert():
    mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128)
    target = "llvm -mcpu=cascadelake"
    extracted_tasks = ms.extract_task_from_relay(mod, target, params)
    tune_tasks = list(
        filter(
            lambda task: "dense" in task.task_name or "batch_matmul" in task.task_name,
            extracted_tasks,
        )
    )
    # Three int8 dense, two int8 batch_matmul, and one fp32 dense
    assert len(tune_tasks) == 6
    for task in tune_tasks:
        relay_func = list(task.mod.functions.values())[0]
        out_type = relay_func.body.checked_type
        if out_type.dtype == "float32":
            continue
        mod = ms.default_config.mod(task.dispatched[0])
        sch = tvm.tir.Schedule(mod)
        block = sch.get_block("compute")
        annotations = sch.get(block).annotations
        assert "schedule_rule" in annotations
        assert "vnni" in annotations["schedule_rule"]
def extract_from_relay(
    mod: IRModule,
    target: Target,
    params: Optional[Dict[str, NDArray]],
    name: str,
    input_shape: List[int],
    *,
    cache_dir: Optional[str] = None,
    opt_level: int = 3,
    pass_config: Optional[Dict[str, Any]] = None,
    disabled_pass: Optional[List[str]] = None,
) -> List[ExtractedTask]:
    """Extract the tuning tasks from a network, with optional on-disk caching.

    Parameters
    ----------
    mod : IRModule
        The IRModule representing the network.
    target : Target
        The target that the network will be deployed to.
    params : Optional[Dict[str, NDArray]]
        The parameters of the network.
    name : str
        The name of the network.
    input_shape : List[int]
        The shape of the input tensor.
    cache_dir : Optional[str]
        The directory used to cache the extracted tasks. If not specified,
        caching is disabled.
    opt_level : int
        The optimization level of the compiler.
    pass_config : Optional[Dict[str, Any]]
        The pass configuration of the compiler.
    disabled_pass : Optional[List[str]]
        The passes disabled during compilation.

    Returns
    -------
    extracted_tasks : List[ExtractedTask]
        The extracted tasks.
    """
    filename = f'tasks-{target.kind.name}-{name}-{",".join(str(i) for i in input_shape)}.json'
    extracted_tasks = _load_cache(cache_dir, filename)
    if extracted_tasks is None:
        extracted_tasks = extract_task_from_relay(
            mod=mod,
            target=target,
            params=params,
            opt_level=opt_level,
            pass_config=pass_config,
            disabled_pass=disabled_pass,
        )
        extracted_tasks = list(extracted_tasks)
        _save_cache(cache_dir, filename, extracted_tasks)
    return extracted_tasks
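# A minimal usage sketch for the caching wrapper above, not part of the original code.
# It assumes `get_network` (as used by the tests in this file) is available and that
# "/tmp/ms_task_cache" is a writable directory; both are illustrative choices only.
def _example_extract_with_cache():
    mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    tasks = extract_from_relay(
        mod,
        target=Target("llvm"),
        params=params,
        name="resnet_18",
        input_shape=[1, 3, 224, 224],
        cache_dir="/tmp/ms_task_cache",  # hypothetical cache location
    )
    # A second call with the same arguments would hit the cached JSON file instead of
    # re-running task extraction.
    for task in tasks:
        print(task.task_name)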
def apply_fixed_schedules(
    relay_mod: Union[RelayFunc, IRModule],
    target: Union[str, Target],
    params: Optional[Dict[str, NDArray]],
    schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool],
    te_filter_func=None,
):
    """Apply fixed schedules (manually written, without any tunable knobs), as specified
    by schedule_fn, to the extracted tasks, and return a database that can be passed to
    ApplyHistoryBest.

    Parameters
    ----------
    relay_mod : Union[RelayFunc, IRModule]
        The Relay module to apply fixed schedules to.
    target : Union[str, Target]
        The target used to extract tasks.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the module.
    schedule_fn : Callable[[ExtractedTask, Schedule], bool]
        A callable that is applied to each extracted task and its corresponding default
        schedule. It returns True if the given schedule should be committed to the
        database, False otherwise.
    te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None
        The filtering function for TE computation.
        If it is a string, it is the name of the filtering function. The built-in
        functions are
          - "meta_schedule.DefaultTaskFilter"
          - "meta_schedule.DefaultTaskFilterAllowExtern"
        If it is None, the default filtering function is used.
        If it is a callable, it is used directly as the filtering function.

    Returns
    -------
    database : Database
        The database containing dummy tuning records for the manually scheduled traces.
    """
    target = Target(target) if isinstance(target, str) else target
    extracted_tasks = ms.extract_task_from_relay(
        relay_mod,
        target,
        params,
        te_filter_func=te_filter_func,
    )
    database = ms.database.MemoryDatabase()
    for task in extracted_tasks:
        mod = ms.default_config.mod(task.dispatched[0])
        sch = Schedule(mod)
        if schedule_fn(task, sch):
            workload = database.commit_workload(mod)
            tune_rec = ms.database.TuningRecord(sch.trace, workload, [0.0], target, [])
            database.commit_tuning_record(tune_rec)
    return database
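# A minimal sketch of the schedule_fn contract described above, not part of the original
# code. The task-name check, the block name "compute" (which matches the dense workloads
# exercised elsewhere in this file), and the trivial parallelization are illustrative
# assumptions; manual_tir_common() below shows the full end-to-end usage.
def _example_schedule_fn(task: ms.ExtractedTask, sch: Schedule) -> bool:
    if "dense" not in task.task_name:
        return False  # leave every other task to its default schedule
    block = sch.get_block("compute")
    outer = sch.get_loops(block)[0]
    sch.parallel(outer)  # a deliberately tiny manual schedule, for illustration only
    return True  # commit this trace to the returned database

# database = apply_fixed_schedules(relay_mod, "llvm", params, _example_schedule_fn)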
def tune_each_task(
    mod,
    target,
    config,
    runner,
    work_dir,
    params,
):
    extracted_tasks = ms.extract_task_from_relay(mod, target, params)
    database = ms.database.JSONDatabase(
        path_workload=os.path.join(work_dir, "default_database_workload.json"),
        path_tuning_record=os.path.join(work_dir, "default_database_tuning_record.json"),
    )
    for task in extracted_tasks:
        # pylint: disable=protected-access
        tune_context = ms.tune.Parse._tune_context(
            tune_context=None,
            mod=ms.tune.Parse._mod(task.dispatched[0]),
            target=target,
            config=config,
            task_name=task.task_name,
            space_generator=None,
            sch_rules=None,
            postprocs=None,
            mutator_probs=None,
            num_threads=os.cpu_count(),
        )
        task_scheduler = ms.tune.Parse._task_scheduler(
            None,
            [tune_context],
            task_weights=[1.0],
            builder=ms.tune.Parse._builder(None),
            runner=ms.tune.Parse._runner(runner),
            database=database,
            max_trials=config.max_trials_per_task,
            cost_model=ms.tune.Parse._cost_model(None),
            measure_callbacks=ms.tune.Parse._callbacks(None),
        )
        # pylint: enable=protected-access
        task_scheduler.tune()
    with target, ms.ApplyHistoryBest(database):
        with PassContext(
            opt_level=3,
            config={"relay.backend.use_meta_schedule": True},
        ):
            return relay_build(mod, target=target, params=params)
def test_extract_task_arm_conv2d_nchwc():
    data_shape = (1, 64, 128, 128)
    weight_shape = (32, 64, 1, 1)
    bias_shape = (weight_shape[0],)
    padding = (1, 1)

    data = relay.var("data", shape=data_shape, dtype="int8")
    weight = relay.var("weight", shape=weight_shape, dtype="int8")
    bias = relay.var("bias", shape=bias_shape, dtype="int32")
    conv2d = relay.nn.conv2d(
        data=data,
        weight=weight,
        kernel_size=weight_shape[2:],
        channels=weight_shape[0],
        padding=padding,
        strides=(1, 1),
        out_dtype="int32",
    )
    bias_add = relay.nn.bias_add(conv2d, bias)
    relay_mod = tvm.IRModule.from_expr(bias_add)

    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
    bias_np = np.random.uniform(1, 10, size=bias_shape).astype("int32")
    params = {"weight": weight_np, "bias": bias_np}

    target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon"
    extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params)
    tune_tasks = list(
        filter(
            lambda task: "conv2d" in task.task_name,
            extracted_tasks,
        )
    )

    assert len(tune_tasks) == 1

    relay_func = list(tune_tasks[0].mod.functions.values())[0]
    out_type = relay_func.body.checked_type

    # Check that the output is in NCHWc layout
    assert list(out_type.shape) == [1, 8, 130, 130, 4]
def extract_and_save_tasks(cache_file):
    """Extract tuning tasks and cache the non-spatial ones in the given directory.

    Parameters
    ----------
    cache_file : str
        The filename of the cached model.

    Returns
    -------
    None
    """
    mod, params_bytearray, _ = _load_cache(args.model_cache_dir, cache_file)
    params = load_param_dict(params_bytearray)
    try:
        extracted_tasks = ms.extract_task_from_relay(mod, target=args.target, params=params)
    except tvm.error.TVMError as error:
        print(str(error))
        return
    task_cache_path = os.path.join(
        args.task_cache_dir, cache_file.split(".")[0] + "_extracted_tasks.json"
    )
    is_spatial = tvm.get_global_func("tir.schedule.IsSpatialPrimFunc")
    with open(task_cache_path, "w", encoding="utf8") as file:
        for i, task in enumerate(extracted_tasks):
            subgraph = task.dispatched[0]
            prim_func = subgraph[subgraph.get_global_vars()[0]]
            if not is_spatial(prim_func):
                subgraph_str = save_json(subgraph)
                json_obj = [task.task_name, json.loads(subgraph_str)]
                json_str = json.dumps(json_obj)
                assert "\n" not in json_str, "Failed to generate single line string."
                if i == len(extracted_tasks) - 1:
                    file.write(json_str)
                else:
                    file.write(json_str + "\n")
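# A sketch of the read-back side of the cache written above, not part of the original
# script: each line holds a [task_name, serialized IRModule] pair. The helper name
# `load_cached_tasks` and the use of tvm.ir.load_json as the inverse of save_json are
# assumptions for illustration.
def load_cached_tasks(task_cache_path):
    from tvm.ir import load_json  # assumed counterpart of save_json

    tasks = []
    with open(task_cache_path, "r", encoding="utf8") as file:
        for line in file:
            task_name, mod_json = json.loads(line)
            # Re-serialize the nested JSON object so load_json can reconstruct the IRModule.
            tasks.append((task_name, load_json(json.dumps(mod_json))))
    return tasks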
def manual_tir_common(do_tune=False):
    M, N, K = 1024, 1024, 1024  # pylint: disable=invalid-name
    data_shape = (M, K)
    weight_shape = (N, K)

    data_dtype = "uint8"
    data = relay.var("data", shape=data_shape, dtype=data_dtype)
    weight = relay.var("weight", shape=weight_shape, dtype="int8")
    bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")

    # dense is tuned by the TIR schedule above, bmm is scheduled by TE (topi/x86/batch_matmul.py)
    dense = relay.nn.dense(data, weight, out_dtype="int32")
    bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
    out = relay.nn.batch_matmul(
        relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
        relay.cast(relay.expand_dims(bias_add, 0), "int8"),
        out_dtype="int32",
    )
    relay_mod = tvm.IRModule.from_expr(out)

    target = "llvm -mcpu=cascadelake -num-cores 4"
    dev = tvm.device(target, 0)

    data = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
    bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")

    ref = (
        relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
        .evaluate()(*[data, weight_np, bias_np])
        .numpy()
    )

    params = {"weight": weight_np, "bias": bias_np}

    if do_tune:
        extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params)
        # Filter out tasks that we don't intend to schedule / tune with TIR.
        tune_tasks = list(
            filter(
                lambda task: "dense" in task.task_name,
                extracted_tasks,
            )
        )
        config = ms.TuneConfig(
            strategy="replay_trace",
            num_trials_per_iter=64,
            max_trials_per_task=20000,
            max_trials_global=20000,
        )

        with tempfile.TemporaryDirectory() as work_dir:
            # postprocs=lambda: [] is important to prevent default post processors from
            # tampering with the manual schedule.
            database = ms.tune_extracted_tasks(
                tune_tasks,
                config,
                work_dir=work_dir,
                postprocs=lambda: [],
            )
    else:

        def schedule_fn(task, sch):
            if "dense" not in task.task_name:
                return False

            block = sch.get_block("compute")

            # Looks up schedule_rule annotation.
            # See the comment in test_tune_relay_manual_tir_vnni().
            schedule_rule = sch.get(block).annotations["schedule_rule"]
            assert "dense_vnni" in schedule_rule

            schedule_dense(block, M, False, sch)
            return True

        database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)

    with ms.ApplyHistoryBest(database):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.backend.use_meta_schedule": True},
        ):
            # pylint: disable=W0105
            """
            The log should say
            Warning: Cannot find workload: tvmgen_default_fused_expand_dims
            Warning: Cannot find workload: tvmgen_default_fused_cast
            Warning: Cannot find workload: tvmgen_default_fused_cast_1
            Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul

            This means batch_matmul and others are scheduled by TE, and dense (the one
            not warned) is found in the meta schedule tuning database during
            ApplyHistoryBest.
            """
            # pylint: enable=W0105
            lib = relay.build(relay_mod, target=target, params=params)

    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
    runtime.set_input("data", data)
    runtime.run()
    out = runtime.get_output(0).numpy()

    np.testing.assert_equal(out, ref)
def test_meta_schedule_integration_extract_from_bert_base():
    # Maps each expected task name to its (weight, buffer shapes) pair.
    expected = {
        "fused_nn_dense_2": (12, [[64, 3072], [768, 3072], [64, 768]]),
        "fused_nn_dense": (48, [[64, 768], [768, 768], [64, 768]]),
        "fused_nn_dense_1": (12, [[64, 768], [3072, 768], [64, 3072]]),
        "fused_subtract_add_sqrt_divide_multiply_add": (
            25,
            [[1, 64, 768], [1, 64, 1], [1, 64, 1], [768], [768], [1, 64, 768]],
        ),
        "fused_nn_batch_matmul": (24, [[12, 64, 64], [12, 64, 64], [12, 64, 64]]),
        "fused_reshape_add_add": (24, [[64, 768], [768], [1, 64, 768], [1, 64, 768]]),
        "fused_variance": (25, [[1, 64, 768], [1, 64, 1], [1, 64, 1]]),
        "fused_mean": (25, [[1, 64, 768], [1, 64, 1]]),
        "fused_reshape_add_reshape_transpose_reshape": (12, [[64, 768], [768], [12, 64, 64]]),
        "fused_reshape_add_multiply_fast_erf_multiply_add_multiply_reshape": (
            12,
            [[64, 3072], [3072], [64, 3072]],
        ),
        "fused_nn_fast_softmax": (12, [[1, 12, 64, 64], [1, 12, 64, 64]]),
        "fused_reshape_add_reshape_transpose_reshape_1": (24, [[64, 768], [768], [12, 64, 64]]),
        "fused_reshape_divide_add": (12, [[12, 64, 64], [1, 1, 1, 64], [1, 12, 64, 64]]),
        "fused_reshape_transpose_reshape": (12, [[12, 64, 64], [64, 768]]),
        "fused_nn_dense_add_fast_tanh": (1, [[1, 768], [768, 768], [1, 768], [1, 768]]),
        "fused_cast_take_add": (1, [[1, 64], [30522, 768], [1, 64, 768], [1, 64, 768]]),
        "fused_take": (1, [[1, 64, 768], [1, 768]]),
        "fused_reshape": (12, [[1, 12, 64, 64], [12, 64, 64]]),
        "fused_reshape_1": (24, [[1, 64, 768], [64, 768]]),
    }
    mod, params, _ = get_network(name="bert_base", input_shape=[1, 64])
    extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params=params)
    assert len(extracted_tasks) == len(expected)
    for t in extracted_tasks:
        prim_func = None
        for _, v in t.dispatched[0].functions.items():
            prim_func = v
        shape = [[int(x) for x in prim_func.buffer_map[b].shape] for b in prim_func.params]
        assert t.task_name in expected
        expected_weight, expected_shape = expected[t.task_name]
        assert expected_weight == t.weight, t.task_name
        assert expected_shape == shape, t.task_name
def test_meta_schedule_dynamic_loop_extent():
    a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32")
    b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC")
    mod = IRModule({"main": relay.Function([a], b)})
    extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params={})
    assert not extracted_tasks
def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
    # Returns None to skip purely spatial workloads, otherwise a PrimFunc built from args.
    def filter_func(args):
        from tvm.te import create_prim_func  # pylint: disable=import-outside-toplevel

        has_complex_op = False
        visited = set()

        def traverse(t):
            nonlocal has_complex_op
            assert t.handle is not None
            if t.handle.value in visited:
                return
            if isinstance(t.op, te.PlaceholderOp):
                pass
            elif isinstance(t.op, te.ComputeOp):
                has_complex_op = has_complex_op or any(
                    isinstance(e, tir.Reduce) for e in t.op.body
                )
                for x in t.op.input_tensors:
                    traverse(x)
            visited.add(t.handle.value)

        for t in args:
            traverse(t)
        if not has_complex_op:
            return None
        return create_prim_func(args)

    mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    extracted_tasks = ms.extract_task_from_relay(
        mod,
        target="llvm",
        params=params,
        te_filter_func=filter_func,
    )
    expected_task_names = [
        "fused_" + s
        for s in [
            "nn_max_pool2d",
            "nn_adaptive_avg_pool2d",
            "nn_dense_add",
            "nn_conv2d_add",
            "nn_conv2d_add_1",
            "nn_conv2d_add_2",
            "nn_conv2d_add_add_nn_relu",
            "nn_conv2d_add_add_nn_relu_1",
            "nn_conv2d_add_nn_relu",
            "nn_conv2d_add_nn_relu_1",
            "nn_conv2d_add_nn_relu_2",
            "nn_conv2d_add_nn_relu_3",
            "nn_conv2d_add_nn_relu_4",
            "nn_conv2d_add_nn_relu_5",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1",
        ]
    ]
    assert len(extracted_tasks) == len(expected_task_names)
    for t in extracted_tasks:
        assert t.task_name in expected_task_names, t.task_name