def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
    """Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
    not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
    much more efficiently if the convolution or dense operator input datatypes are int16 instead of
    int8. More details are present at https://github.com/apache/tvm/pull/4277.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """

    def _broadcast_output_axis():
        # Resolve which kernel axis holds the output channels, per operator kind.
        if hasattr(attrs, "kernel_layout"):
            # Spatial operations: find the "O" dimension in the kernel layout.
            return tvm.tir.layout(attrs["kernel_layout"]).index_of("O")
        if isinstance(attrs, relay.op.op_attrs.DenseAttrs):
            # Dense operations broadcast over a [N, K] layout.
            return 0
        if isinstance(attrs, relay.op.op_attrs.MatmulAttrs):
            # Matrix multiplication broadcasts over a [K, N] layout.
            return 1
        raise TVMError(
            "Legalization of %s is not yet supported with per channel parameters"
            % str(type(attrs))
        )

    data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs

    # Upcast the data to int16 and remove its zero point.
    shift_data = relay.subtract(
        relay.cast(data, dtype="int16"), relay.cast(input_zero_point, dtype="int16")
    )

    if len(types[3].shape) == 0:
        # Scalar kernel zero point: a plain subtract suffices.
        shift_kernel = relay.subtract(
            relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, dtype="int16")
        )
    else:
        # Per-channel kernel zero point: broadcast the (negated) zero point
        # along the kernel's output-channel axis via bias_add.
        shift_kernel = relay.nn.bias_add(
            relay.cast(kernel, dtype="int16"),
            -relay.cast(kernel_zero_point, dtype="int16"),
            _broadcast_output_axis(),
        )

    new_attrs = {name: attrs[name] for name in attrs.keys()}
    return relay_op(shift_data, shift_kernel, **new_attrs)
def from_source(src, func_lineno=0):
    """Parse the src into TIR

    Parameters
    ----------
    src : str
        Pruned source of original script
    func_lineno : Optional[int]
        The line number of the first line of the script to be parsed

    Returns
    -------
    functions : PrimFunc or IRModule
        The PrimFunc or IRModule in IR.

    Raises
    ------
    HybridParserError
        When parsing fails; the message carries line/column information.
    TVMError
        When a TVM-internal error occurs; line info is injected into the message.
    """
    root = ast.parse(src)
    parser = HybridParser(src, func_lineno)
    try:
        return parser.visit(root)
    except HybridParserError:
        # Already carries line/col context; bare `raise` preserves the
        # original traceback (unlike `raise e`, which re-anchors it here).
        raise
    except TVMError as e:
        # TVM internal c++ error, we have to process the error message and inject line info
        inject_e = str(e).split("\n")
        # The last line looks like "<prefix>: <message>"; keep only the message.
        msg = inject_e[-1].split(":", maxsplit=1)[1].strip()
        inject_e = inject_e[:-1]
        inject_e.extend(
            parser.wrap_line_col(msg, parser.current_lineno, parser.current_col_offset).split("\n")
        )
        inject_e[-1] = "TVM" + inject_e[-1][6:]
        # Chain the original error so the root cause stays in the traceback.
        raise TVMError("\n".join(inject_e)) from e
    except Exception as e:
        # Any other failure: wrap with line/col info, chaining the cause.
        inject_e = parser.wrap_line_col(str(e), parser.current_lineno, parser.current_col_offset)
        raise HybridParserError(inject_e) from e
def extract_tasks(
    mod,
    params,
    target,
    target_host=None,
    hardware_params=None,
    include_simple_tasks=False,
    dump_workload_to_dag_log=None,
    opt_level=3,
):
    """Extract tuning tasks from a relay program.

    Parameters
    ----------
    mod: tvm.IRModule or relay.function.Function
        The module or function to tune
    params: dict of str to numpy array
        The associated parameters of the program
    target: Union[tvm.target.Target, str]
        The compilation target
    target_host: Optional[Union[tvm.target.Target, str]]
        The host compilation target
    hardware_params : Optional[HardwareParams]
        Hardware parameters used for the search tasks
    include_simple_tasks: bool
        Whether to extract simple tasks that do not include complicated ops.
    dump_workload_to_dag_log: Optional[str]
        A file to dump an association between the workload keys and the actual DAG
    opt_level : Optional[int]
        The optimization level of the task extractions.

    Returns
    -------
    tasks: List[SearchTask]
        The tasks in this network
    weights: List[int]
        The weight (i.e. the number of appearance) of extracted tasks

    Raises
    ------
    TVMError
        If any error was collected during task extraction.
    """
    # pylint: disable=import-outside-toplevel
    target, target_host = Target.canon_target_and_host(target, target_host)

    # Run the compiler to collect all TOPI calls during compilation.
    env = TracingEnvironment(
        TracingMode.EXTRACT_TASK if include_simple_tasks else TracingMode.EXTRACT_COMPLEX_TASK_ONLY
    )

    dispatch_ctx = DispatchContext.current
    old_verbose = dispatch_ctx.verbose
    dispatch_ctx.verbose = 0

    errors = []
    try:
        with env:
            # Wrap build call in a new thread to avoid the conflict
            # between python's multiprocessing and tvm's thread pool
            build_thread = threading.Thread(
                target=call_all_topi_funcs, args=(mod, params, target, errors, opt_level)
            )
            build_thread.start()
            build_thread.join()
    finally:
        # Bug fix: restore verbosity even when extraction raises; previously a
        # failure here left the dispatch context silenced (verbose == 0).
        dispatch_ctx.verbose = old_verbose

    if errors:
        error_strings = ["Task extraction had the following errors:"] + errors
        raise TVMError("\n".join(error_strings))

    # create search tasks
    tasks = []
    weights = []
    for wkl_key, (weight, func_names) in env.wkl_key_to_weight.items():
        tasks.append(
            SearchTask(
                workload_key=wkl_key,
                target=target,
                hardware_params=hardware_params,
                # When auto scheduler is used in end to end network, try to apply layout rewrite
                # to improve the overall performance
                layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
                task_inputs=(
                    env.wkl_key_to_input_names[wkl_key]
                    if wkl_key in env.wkl_key_to_input_names
                    else None
                ),
                task_inputs_save_to_file=True,
                desc=",".join(func_names),
            )
        )
        weights.append(int(weight))

    if dump_workload_to_dag_log is not None:
        with open(dump_workload_to_dag_log, "w") as f:
            json.dump({task.workload_key: str(task.compute_dag) for task in tasks}, f)

    return tasks, weights