示例#1
0
    def print_best(self, log_file, print_mode="schedule"):
        """Return the best schedule as python schedule API code or CUDA source code.

        Despite the name, nothing is printed; the code is returned as a string.

        Parameters
        ----------
        log_file : str
           The name of the log file
        print_mode: str
           if "schedule", return the best schedule as python schedule API code.
           if "cuda", return the best schedule as CUDA source code
           (requires this task's target kind to be "cuda").

        Returns
        -------
        code: str
            The best schedule code in python API or CUDA source code

        Raises
        ------
        RuntimeError
            If no record for this task's workload is found in `log_file`
            (``load_best_record`` returns None).
        ValueError
            If `print_mode` is neither "schedule" nor "cuda", or if it is
            "cuda" while this task's target kind is not "cuda".
        """
        inp, _ = load_best_record(log_file, self.workload_key)
        if inp is None:
            raise RuntimeError(
                "Cannot find any valid schedule for %s in file %s" %
                (self.workload_key, log_file))

        if print_mode == "schedule":
            return self.compute_dag.print_python_code_from_state(inp.state)
        if print_mode == "cuda":
            # Explicit check instead of `assert`: an assert is stripped under
            # `python -O`, which would let a schedule tuned for another target
            # be silently built as CUDA below.
            target_kind = self.target.kind.name
            if target_kind != "cuda":
                raise ValueError(
                    "print_mode='cuda' requires a CUDA target, got target kind %s"
                    % target_kind)
            sch, args = self.compute_dag.apply_steps_from_state(inp.state)
            func = build(sch, args, "cuda")
            # NOTE(review): assumes the CUDA device module is the first module
            # imported by the built host module — confirm against build().
            return func.imported_modules[0].get_source()
        raise ValueError("Invalid print_mode: %s" % print_mode)
示例#2
0
文件: measure.py 项目: ybai62868/tvm
def _timed_func(inp_serialized, build_func, verbose):
    """Instantiate and compile one serialized measure input, timing the work.

    Parameters
    ----------
    inp_serialized :
        A serialized ``MeasureInput``, restored with ``MeasureInput.deserialize``.
    build_func :
        Passed to ``export_library``; its ``output_format`` attribute gives the
        file extension of the exported library.
    verbose : int
        If >= 1, print "." on success or ".E" on failure (without a newline).

    Returns
    -------
    tuple
        ``(filename, args, error_no, error_msg, elapsed_seconds)``.
        ``filename`` is a path inside a fresh temporary directory whenever the
        schedule was instantiated, even if compilation then failed, and ``""``
        if instantiation failed. ``args`` is ``[]`` if instantiation failed.
        ``error_msg`` is None on success, otherwise the result of
        ``make_traceback_info()``. ``elapsed_seconds`` is wall-clock time.

    Notes
    -----
    The temporary directory is not removed here; presumably the caller that
    consumes ``filename`` cleans it up — verify.
    """
    tic = time.time()
    inp = MeasureInput.deserialize(inp_serialized)
    task = inp.task

    error_no = MeasureErrorNo.NO_ERROR
    error_msg = None
    args = []

    # Stage 1: replay the recorded transform steps to obtain a schedule.
    try:
        sch, args = task.compute_dag.apply_steps_from_state(
            inp.state,
            layout_rewrite=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED)
    # pylint: disable=broad-except
    except Exception:
        error_no = MeasureErrorNo.INSTANTIATION_ERROR
        error_msg = make_traceback_info()

    # Compare against the named constant (as the verbose branch below does)
    # instead of the magic number 0.
    if error_no == MeasureErrorNo.NO_ERROR:
        # Stage 2: compile the schedule and export it as a library file.
        dirname = tempfile.mkdtemp()
        filename = os.path.join(dirname,
                                "tmp_func." + build_func.output_format)

        try:
            with transform.PassContext():
                func = build_module.build(sch,
                                          args,
                                          target=task.target,
                                          target_host=task.target_host)
            func.export_library(filename, build_func)
        # pylint: disable=broad-except
        except Exception:
            error_no = MeasureErrorNo.COMPILE_HOST
            error_msg = make_traceback_info()
    else:
        filename = ""

    if verbose >= 1:
        if error_no == MeasureErrorNo.NO_ERROR:
            print(".", end="")
        else:
            print(".E", end="")  # Build error

    return filename, args, error_no, error_msg, time.time() - tic
示例#3
0
    def timed_func():
        """Instantiate and compile ``measure_inputs[index]``, timing the work.

        Closure: reads ``measure_inputs``, ``index``, ``build_func`` and
        ``verbose`` from the enclosing scope.

        Returns
        -------
        tuple
            ``(filename, args, error_no, error_msg, elapsed_seconds)``.
            ``filename`` is a path inside a fresh temporary directory whenever
            the schedule was instantiated, even if compilation then failed,
            and ``""`` if instantiation failed. ``args`` is ``[]`` if
            instantiation failed. ``error_msg`` is None on success, otherwise
            the result of ``make_error_msg()``.
        """
        tic = time.time()
        inp = measure_inputs[index]
        task = inp.task

        error_no = MeasureErrorNo.NO_ERROR
        error_msg = None
        args = []

        # Stage 1: replay the recorded transform steps to obtain a schedule.
        try:
            sch, args = task.compute_dag.apply_steps_from_state(
                inp.state, layout_rewrite=True)
        # pylint: disable=broad-except
        except Exception:
            error_no = MeasureErrorNo.INSTANTIATION_ERROR
            error_msg = make_error_msg()

        # Compare against the named constant (as the verbose branch below
        # does) instead of the magic number 0.
        if error_no == MeasureErrorNo.NO_ERROR:
            # Stage 2: compile the schedule and export it as a library file.
            # NOTE(review): the temp dir is not removed here; presumably the
            # consumer of `filename` cleans it up — verify.
            dirname = tempfile.mkdtemp()
            filename = os.path.join(dirname,
                                    "tmp_func." + build_func.output_format)

            try:
                # TODO(merrymercy): Port the unroll pass.
                with transform.PassContext():
                    func = build_module.build(sch,
                                              args,
                                              target=task.target,
                                              target_host=task.target_host)
                func.export_library(filename, build_func)
            # pylint: disable=broad-except
            except Exception:
                error_no = MeasureErrorNo.COMPILE_HOST
                error_msg = make_error_msg()
        else:
            filename = ""

        if verbose >= 1:
            if error_no == MeasureErrorNo.NO_ERROR:
                print(".", end="")
            else:
                print(".E", end="")  # Build error
        return filename, args, error_no, error_msg, time.time() - tic