Example #1
import tvm
from tvm import relay
from tvm.error import TVMError


def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
    """Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
    not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
    much more efficiently if the convolution or dense operator input datatypes are int16 instead of
    int8. More details are present at https://github.com/apache/tvm/pull/4277.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """

    # Collect the input exprs.
    data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs

    shift_data = relay.subtract(
        relay.cast(data, dtype="int16"), relay.cast(input_zero_point, dtype="int16")
    )
    # If kernel zero point is a scalar we can directly subtract it.
    if len(types[3].shape) == 0:
        shift_kernel = relay.subtract(
            relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, dtype="int16")
        )
    # Otherwise it needs to be broadcast.
    else:
        # Determine output axis of kernel for spatial operations.
        if hasattr(attrs, "kernel_layout"):
            output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O")
        # For dense operations, broadcast to [N, K] layout.
        elif isinstance(attrs, relay.op.op_attrs.DenseAttrs):
            output_axis = 0
        # For matrix multiplication, broadcast to [K, N] layout instead.
        elif isinstance(attrs, relay.op.op_attrs.MatmulAttrs):
            output_axis = 1
        else:
            raise TVMError(
                "Legalization of %s is not yet supported with per channel parameters"
                % str(type(attrs))
            )

        shift_kernel = relay.nn.bias_add(
            relay.cast(kernel, dtype="int16"),
            -relay.cast(kernel_zero_point, dtype="int16"),
            output_axis,
        )
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    return relay_op(shift_data, shift_kernel, **new_attrs)
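A hedged usage sketch: in TVM, a helper like this is typically invoked from a target-specific QNN legalization hook. The sketch below follows the registration pattern of tvm/relay/qnn/op/legalizations.py, but treat the exact generic function name and target key as assumptions rather than verbatim API.

from tvm import relay

@qnn_conv2d_legalize.register("arm_cpu")
def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
    # Assumed fallback path: use the int16 shift-based lowering when the
    # target lacks fast int8 instructions (e.g. no dot-product support).
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)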
Example #2
import ast  # assumed: the stdlib parser; some TVM versions used typed_ast here

from tvm.error import TVMError

# HybridParser and HybridParserError are defined alongside this function
# in the surrounding module.


def from_source(src, func_lineno=0):
    """Parse the src into TIR

    Parameters
    ----------
    src : str
        Pruned source of original script
    func_lineno : Optional[int]
        The line number of the first line of the script to be parsed
    Returns
    -------
    functions : PrimFunc or IRModule
        The PrimFunc or IRModule in IR.
    """

    root = ast.parse(src)
    parser = HybridParser(src, func_lineno)

    try:
        return parser.visit(root)
    except HybridParserError:
        # Parser errors already carry injected line/column info; re-raise as-is.
        raise
    except TVMError as e:
        # TVM internal c++ error, we have to process the error message and inject line info
        inject_e = str(e).split("\n")
        msg = inject_e[-1].split(":", maxsplit=1)[1].strip()
        inject_e = inject_e[:-1]
        inject_e.extend(
            parser.wrap_line_col(msg, parser.current_lineno,
                                 parser.current_col_offset).split("\n"))
        inject_e[-1] = "TVM" + inject_e[-1][6:]
        raise TVMError("\n".join(inject_e))
    except Exception as e:
        inject_e = parser.wrap_line_col(
            str(e), parser.current_lineno, parser.current_col_offset
        )
        raise HybridParserError(inject_e) from e
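A minimal usage sketch, assuming a script string in whatever syntax HybridParser accepts (the script body below is illustrative only; the call shape and error handling are the point):

src = '''
def placeholder_func(a, b):
    ...
'''

try:
    # func_lineno lets injected error locations point back to the original file
    # when src was pruned out of a larger script.
    parsed = from_source(src, func_lineno=10)
except HybridParserError as err:
    # The message already carries wrapped line/column information.
    print(err)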
Example #3
import json
import threading

from tvm.error import TVMError
from tvm.target import Target

# TracingEnvironment, TracingMode, DispatchContext, SearchTask, LayoutRewriteOption,
# and call_all_topi_funcs come from the surrounding tvm.auto_scheduler modules.


def extract_tasks(
    mod,
    params,
    target,
    target_host=None,
    hardware_params=None,
    include_simple_tasks=False,
    dump_workload_to_dag_log=None,
    opt_level=3,
):
    """Extract tuning tasks from a relay program.

    Parameters
    ----------
    mod : tvm.IRModule or relay.function.Function
        The module or function to tune
    params : dict of str to numpy array
        The associated parameters of the program
    target : Union[tvm.target.Target, str]
        The compilation target
    target_host : Optional[Union[tvm.target.Target, str]]
        The host compilation target
    hardware_params : Optional[HardwareParams]
        Hardware parameters used for the search tasks
    include_simple_tasks : bool
        Whether to extract simple tasks that do not include complicated ops.
    dump_workload_to_dag_log : Optional[str]
        A file to dump an association between the workload keys and the actual DAG
    opt_level : Optional[int]
        The optimization level of the task extraction.

    Returns
    -------
    tasks : List[SearchTask]
        The tasks in this network
    weights : List[int]
        The weight (i.e. the number of appearances) of extracted tasks
    """
    target, target_host = Target.canon_target_and_host(target, target_host)

    # Run the compiler to collect all TOPI calls during compilation.
    env = TracingEnvironment(
        TracingMode.EXTRACT_TASK if include_simple_tasks else TracingMode.EXTRACT_COMPLEX_TASK_ONLY
    )

    dispatch_ctx = DispatchContext.current
    old_verbose = dispatch_ctx.verbose
    dispatch_ctx.verbose = 0

    errors = []
    with env:
        # Wrap the build call in a new thread to avoid the conflict
        # between Python's multiprocessing and TVM's thread pool.
        build_thread = threading.Thread(
            target=call_all_topi_funcs, args=(mod, params, target, errors, opt_level)
        )
        build_thread.start()
        build_thread.join()

    # Restore the dispatch context's verbosity before potentially raising.
    dispatch_ctx.verbose = old_verbose

    if errors:
        error_strings = ["Task extraction had the following errors:"] + errors
        raise TVMError("\n".join(error_strings))

    # create search tasks
    tasks = []
    weights = []
    for wkl_key, (weight, func_names) in env.wkl_key_to_weight.items():
        tasks.append(
            SearchTask(
                workload_key=wkl_key,
                target=target,
                hardware_params=hardware_params,
                # When the auto-scheduler is used on an end-to-end network, try to
                # apply layout rewrite to improve the overall performance.
                layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
                task_inputs=env.wkl_key_to_input_names.get(wkl_key),
                task_inputs_save_to_file=True,
                desc=",".join(func_names),
            )
        )
        weights.append(int(weight))

    if dump_workload_to_dag_log is not None:
        with open(dump_workload_to_dag_log, "w") as f:
            json.dump({task.workload_key: str(task.compute_dag) for task in tasks}, f)

    return tasks, weights
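As a usage sketch (following the common auto-scheduler tuning flow; `mod` and `params` are assumed to come from a Relay frontend such as relay.frontend.from_onnx), the returned tasks and weights feed straight into a task scheduler:

from tvm import auto_scheduler

# Assumption: mod/params were produced elsewhere, e.g. by a Relay frontend.
tasks, task_weights = extract_tasks(mod, params, target="llvm")
for idx, (task, weight) in enumerate(zip(tasks, task_weights)):
    print(f"Task {idx} (weight {weight}): {task.desc}")

# Weights bias the scheduler toward tasks that appear more often in the network.
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tuner.tune(auto_scheduler.TuningOptions(num_measure_trials=200))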