Esempio n. 1
0
 def _func():
     """Lower ``mod`` for ``target`` with the Relay VM compiler.

     Every input (``env``, ``opt_level``, ``pass_config``, ``disabled_pass``,
     ``params``, ``mod``, ``target``) is read from the enclosing scope.

     The contexts are entered outermost-first:

     * ``env`` (NOTE(review): its type is set up by the caller — confirm);
     * ``_autotvm_silencer()`` (presumably suppresses AutoTVM log noise — confirm);
     * a ``transform.PassContext`` built from ``opt_level``, ``pass_config``
       and ``disabled_pass``.

     ``params`` are bound to the compiler only when truthy. Nothing is
     returned; the call is made for its lowering side effects.
     """
     with env:
         with _autotvm_silencer():
             with transform.PassContext(
                     config=pass_config,
                     disabled_pass=disabled_pass,
                     opt_level=opt_level,
             ):
                 vm_compiler = vm.VMCompiler()
                 # Empty or missing params: nothing to bind before lowering.
                 if params:
                     vm_compiler.set_params(params)
                 vm_compiler.lower(mod, target)
Esempio n. 2
0
def _timed_func(inp_serialized, build_func, verbose):
    """Instantiate, compile and export one serialized measure input.

    Parameters
    ----------
    inp_serialized :
        A serialized ``MeasureInput``, restored with
        ``MeasureInput.deserialize``.
    build_func :
        Passed to ``export_library``. Its ``output_format`` attribute is used
        as the extension of the exported file.
    verbose : int
        When ``>= 1``, prints ``"."`` on success or ``".E"`` on failure,
        without a newline.

    Returns
    -------
    tuple
        ``(filename, args, error_no, error_msg, elapsed)``:

        * ``elapsed`` is the wall-clock seconds spent in this call.
        * ``filename`` is ``""`` when the schedule could not be instantiated.
          When compilation fails, it still holds the intended output path,
          which may not exist.
        * ``args`` is ``[]`` if instantiation failed.
        * ``error_msg`` is ``None`` on success.

    Notes
    -----
    Failures are reported through ``error_no``/``error_msg`` rather than
    raised. The temporary directory created for the library is not removed
    here (NOTE(review): presumably the caller cleans it up — confirm).
    """
    tic = time.time()
    inp = MeasureInput.deserialize(inp_serialized)
    task = inp.task

    error_no = MeasureErrorNo.NO_ERROR
    error_msg = None
    args = []

    try:
        sch, args = task.compute_dag.apply_steps_from_state(
            inp.state,
            layout_rewrite=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED)
    # Any failure while applying the steps is recorded, not raised.
    # pylint: disable=broad-except
    except Exception:
        error_no = MeasureErrorNo.INSTANTIATION_ERROR
        error_msg = make_traceback_info()

    # Gate on the same named constant used to initialise and report
    # ``error_no``, rather than on a bare literal 0.
    if error_no == MeasureErrorNo.NO_ERROR:
        dirname = tempfile.mkdtemp()
        filename = os.path.join(dirname,
                                "tmp_func." + build_func.output_format)

        try:
            with transform.PassContext():
                func = build_module.build(sch,
                                          args,
                                          target=task.target,
                                          target_host=task.target_host)
            func.export_library(filename, build_func)
        # Build/export failures are recorded as a host compile error.
        # pylint: disable=broad-except
        except Exception:
            error_no = MeasureErrorNo.COMPILE_HOST
            error_msg = make_traceback_info()
    else:
        filename = ""

    if verbose >= 1:
        if error_no == MeasureErrorNo.NO_ERROR:
            print(".", end="")
        else:
            print(".E", end="")  # Build error

    return filename, args, error_no, error_msg, time.time() - tic
Esempio n. 3
0
    def timed_func():
        """Instantiate, compile and export ``measure_inputs[index]``.

        ``measure_inputs``, ``index``, ``build_func`` and ``verbose`` are
        read from the enclosing scope. The schedule is built with
        ``layout_rewrite=True``. The library is exported to
        ``tmp_func.<build_func.output_format>`` inside a fresh
        ``tempfile.mkdtemp()`` directory, which is not removed here.

        Returns
        -------
        tuple
            ``(filename, args, error_no, error_msg, elapsed)``:

            * ``filename`` is ``""`` if instantiation failed. If compilation
              failed, it is still the intended path, which may not exist.
            * ``args`` stays ``[]`` when instantiation failed.
            * ``error_msg`` is ``None`` on success.
            * ``elapsed`` is wall-clock seconds spent in this call.

        Failures are reported through ``error_no``/``error_msg``, not raised.
        """
        tic = time.time()
        inp = measure_inputs[index]
        task = inp.task

        error_no = MeasureErrorNo.NO_ERROR
        error_msg = None
        args = []

        try:
            sch, args = task.compute_dag.apply_steps_from_state(
                inp.state, layout_rewrite=True)
        # Any failure while applying the steps is recorded, not raised.
        # pylint: disable=broad-except
        except Exception:
            error_no = MeasureErrorNo.INSTANTIATION_ERROR
            error_msg = make_error_msg()

        # Gate on the named constant used to initialise and report
        # ``error_no``, rather than on a bare literal 0.
        if error_no == MeasureErrorNo.NO_ERROR:
            dirname = tempfile.mkdtemp()
            filename = os.path.join(dirname,
                                    "tmp_func." + build_func.output_format)

            try:
                # TODO(merrymercy): Port the unroll pass.
                with transform.PassContext():
                    func = build_module.build(sch,
                                              args,
                                              target=task.target,
                                              target_host=task.target_host)
                func.export_library(filename, build_func)
            # Build/export failures are recorded as a host compile error.
            # pylint: disable=broad-except
            except Exception:
                error_no = MeasureErrorNo.COMPILE_HOST
                error_msg = make_error_msg()
        else:
            filename = ""

        if verbose >= 1:
            if error_no == MeasureErrorNo.NO_ERROR:
                print(".", end="")
            else:
                print(".E", end="")  # Build error
        return filename, args, error_no, error_msg, time.time() - tic
Esempio n. 4
0
def extract_task_from_relay(
    mod: IRModule,
    target: Target,
    params: Optional[Dict[str, NDArray]] = None,
    *,
    opt_level: int = 3,
    pass_config: Optional[Dict[str, Any]] = None,
    disabled_pass: Optional[List[str]] = None,
) -> List[ExtractedTask]:
    """Extract tuning tasks from a relay program.

    Parameters
    ----------
    mod : IRModule
        The module to tune. A ``tvm.relay.Function`` is also accepted and is
        wrapped with ``IRModule.from_expr``.
    target : tvm.target.Target
        The compilation target. A non-``Target`` value is converted with
        ``Target(target)``.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the program. ``numpy.ndarray`` values are
        converted with ``nd.array``; other values are passed through.
    opt_level : int
        The optimization level of the compiler.
    pass_config : Optional[Dict[str, Any]]
        The pass config of the compiler. Defaults to
        ``{"relay.backend.use_meta_schedule": True}`` when None.
    disabled_pass : Optional[List[str]]
        The list of disabled passes of the compiler. Defaults to ``[]``.

    Returns
    -------
    tasks: List[ExtractedTask]
        The tasks extracted from this network.
    """
    # pylint: disable=import-outside-toplevel
    from tvm.relay import Function as RelayFunc

    # pylint: enable=import-outside-toplevel

    # Resolve the backend entry point first so a missing registration fails
    # before any argument is converted.
    backend_extract = get_global_func(
        "relay.backend.MetaScheduleExtractTask",
        allow_missing=False,
    )

    if isinstance(mod, RelayFunc):
        mod = IRModule.from_expr(mod)
    if not isinstance(target, Target):
        target = Target(target)
    disabled_pass = [] if disabled_pass is None else disabled_pass
    pass_config = ({"relay.backend.use_meta_schedule": True}
                   if pass_config is None else pass_config)
    source_params = {} if params is None else params
    relay_params = {
        key: (nd.array(value) if isinstance(value, np.ndarray) else value)
        for key, value in source_params.items()
    }

    # Contexts are entered outermost-first, equivalent to a single
    # multi-item ``with`` statement.
    with autotvm_silencer():
        with target:
            with transform.PassContext(
                opt_level=opt_level,
                config=pass_config,
                disabled_pass=disabled_pass,
            ):
                extracted = backend_extract(mod, target, relay_params)
                return list(extracted)
Esempio n. 5
0
def extract_task_from_relay(
    mod: IRModule,
    target: Target,
    params: Optional[Dict[str, NDArray]] = None,
    *,
    opt_level: int = 3,
    pass_config: Optional[Dict[str, Any]] = None,
    disabled_pass: Optional[List[str]] = None,
    te_filter_func: Union[str, None, Callable[[List[Tensor]],
                                              PrimFunc]] = None,
) -> List[ExtractedTask]:
    """Extract tuning tasks from a relay program.

    Parameters
    ----------
    mod : IRModule
        The module to tune. A ``tvm.relay.Function`` is also accepted and is
        wrapped with ``IRModule.from_expr``.
    target : tvm.target.Target
        The compilation target. A non-``Target`` value is converted with
        ``Target(target)``.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the program. ``numpy.ndarray`` values are
        converted with ``nd.array``.
    opt_level : int
        The optimization level of the compiler.
    pass_config : Optional[Dict[str, Any]]
        The pass config of the compiler. Defaults to
        ``{"relay.backend.use_meta_schedule": True}`` when None.
    disabled_pass : Optional[List[str]]
        The list of disabled passes of the compiler. Defaults to ``[]``.
    te_filter_func : Union[str, None, Callable[[List[tvm.te.Tensor]], tvm.tir.PrimFunc]]
        The filter used to decide which extracted tasks are kept.

        - If it's a string, it is resolved with ``get_global_func`` to a
          registered function. Built-in functions are

          - "meta_schedule.DefaultTaskFilter"
          - "meta_schedule.DefaultTaskFilterAllowExtern"

        - If it's None, it is forwarded unchanged and the default filtering
          function is used.
        - If it's a callable, it is used as the filtering function. It
          receives the TE tensors of a candidate task and, per the signature,
          produces a ``PrimFunc``. NOTE(review): earlier docs described the
          return as ``bool``; confirm the backend's exact contract, e.g.
          whether returning None filters the task out.

    Returns
    -------
    tasks: List[ExtractedTask]
        The tasks extracted from this network.
    """
    # pylint: disable=import-outside-toplevel
    from tvm import autotvm
    from tvm.relay import Function as RelayFunc

    # pylint: enable=import-outside-toplevel

    if isinstance(te_filter_func, str):
        te_filter_func = get_global_func(te_filter_func)
    extract_task_func = get_global_func(
        "relay.backend.MetaScheduleExtractTask",
        allow_missing=False,
    )

    if isinstance(mod, RelayFunc):
        mod = IRModule.from_expr(mod)
    if not isinstance(target, Target):
        target = Target(target)
    if disabled_pass is None:
        disabled_pass = []
    if pass_config is None:
        pass_config = {"relay.backend.use_meta_schedule": True}
    if params is None:
        params = {}
    relay_params = {}
    for name, param in params.items():
        if isinstance(param, np.ndarray):
            param = nd.array(param)
        relay_params[name] = param

    with target, autotvm_silencer(), transform.PassContext(
            opt_level=opt_level,
            config=pass_config,
            disabled_pass=disabled_pass,
    ):
        # For non-CUDA targets with no AutoTVM dispatch context configured
        # (the current one is the fallback), enter the TopHub context for this
        # target. NOTE(review): presumably this makes pre-tuned AutoTVM configs
        # visible during extraction; confirm why CUDA is excluded.
        if target.kind.name != "cuda" and isinstance(
                autotvm.DispatchContext.current, autotvm.FallbackContext):
            tophub_context = autotvm.tophub.context(target)
        else:
            tophub_context = autotvm.utils.EmptyContext()
        with tophub_context:
            return list(
                extract_task_func(mod, target, relay_params, te_filter_func))
Esempio n. 6
0
def extract_task_from_relay(
    mod: Union[IRModule, RelayFunc],
    target: Target,
    params: Optional[Dict[str, NDArray]] = None,
    *,
    opt_level: int = 3,
    pass_config: Optional[Dict[str, Any]] = None,
    disabled_pass: Optional[List[str]] = None,
) -> List[ExtractedTask]:
    """Extract tuning tasks from a relay program.

    Parameters
    ----------
    mod : Union[tvm.IRModule, tvm.relay.Function]
        The module or function to tune. A function is wrapped with
        ``IRModule.from_expr``.
    target : tvm.target.Target
        The compilation target. A non-``Target`` value (e.g. a string) is
        converted with ``Target(target)``.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the program. None means no parameters.
        ``numpy.ndarray`` values are converted with ``nd.array``.
    opt_level : int
        The optimization level of the compiler.
    pass_config : Optional[Dict[str, Any]]
        The pass config of the compiler. Defaults to
        ``{"relay.backend.use_meta_schedule": True}`` when None.
    disabled_pass : Optional[List[str]]
        The list of disabled passes of the compiler. Defaults to ``[]``.

    Returns
    -------
    tasks: List[ExtractedTask]
        The tasks extracted from this network, in reverse of the order the
        backend produced them.
    """
    # ``allow_missing=False`` makes a missing registration raise immediately.
    # This replaces a bare ``assert``, which is stripped under ``python -O``.
    extract_task_func = get_global_func(
        "relay.backend.MetaScheduleExtractTask",
        allow_missing=False,
    )

    if isinstance(mod, RelayFunc):
        mod = IRModule.from_expr(mod)
    # A single conversion covers strings and any other non-Target value.
    if not isinstance(target, Target):
        target = Target(target)
    if disabled_pass is None:
        disabled_pass = []
    if pass_config is None:
        pass_config = {"relay.backend.use_meta_schedule": True}
    # ``params`` defaults to None. Without this guard, ``params.items()``
    # raised AttributeError on the default call.
    if params is None:
        params = {}

    relay_params = {}
    for name, param in params.items():
        if isinstance(param, np.ndarray):
            param = nd.array(param)
        relay_params[name] = param

    with target, transform.PassContext(
            opt_level=opt_level,
            config=pass_config,
            disabled_pass=disabled_pass,
    ):
        tasks = extract_task_func(mod, target, relay_params)
        # Tasks are extracted via post order visit, return the reversed list.
        return list(reversed(tasks))