def print_best(self, log_file, print_mode="schedule"):
    """Return the best schedule as python schedule API code or CUDA source code.

    The code is returned as a string; nothing is printed.

    Parameters
    ----------
    log_file : str
        The name of the log file
    print_mode : str
        if "schedule", return the best schedule as python schedule API code.
        if "cuda", return the best schedule as CUDA source code.

    Returns
    -------
    code : str
        The best schedule code in python API or CUDA source code

    Raises
    ------
    RuntimeError
        If ``load_best_record`` finds no record for ``self.workload_key``
        in ``log_file``.
    ValueError
        If ``print_mode`` is neither "schedule" nor "cuda", or if it is
        "cuda" while ``self.target`` is not a CUDA target.
    """
    inp, _ = load_best_record(log_file, self.workload_key)
    if inp is None:
        raise RuntimeError(
            "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file)
        )

    if print_mode == "schedule":
        return self.compute_dag.print_python_code_from_state(inp.state)

    if print_mode == "cuda":
        # Explicit check instead of ``assert``: an assert is stripped under
        # ``python -O``, after which a non-CUDA schedule would silently be
        # compiled with the hard-coded "cuda" target below.
        if self.target.kind.name != "cuda":
            raise ValueError(
                "print_mode 'cuda' requires a CUDA target, but the task target is %s"
                % self.target.kind.name
            )
        sch, args = self.compute_dag.apply_steps_from_state(inp.state)
        func = build(sch, args, "cuda")
        # NOTE(review): assumes the first imported module of the built module
        # holds the generated CUDA device code — confirm against build().
        return func.imported_modules[0].get_source()

    raise ValueError("Invalid print_mode: %s" % print_mode)
def _timed_func(inp_serialized, build_func, verbose):
    """Instantiate and compile one serialized measure input, timing the work.

    Parameters
    ----------
    inp_serialized :
        A serialized measure input, decoded with ``MeasureInput.deserialize``.
    build_func :
        Passed to ``export_library``; its ``output_format`` attribute supplies
        the extension of the exported library file.
    verbose :
        Verbosity level; when ``>= 1`` a progress marker is printed
        ("." on success, ".E" on any error).

    Returns
    -------
    tuple
        ``(filename, args, error_no, error_msg, elapsed)`` where ``filename`` is
        the path of the exported library ("" if schedule instantiation failed),
        ``args`` is what ``apply_steps_from_state`` returned (``[]`` if it
        failed), ``error_no`` is a ``MeasureErrorNo`` value, ``error_msg`` is
        None on success, and ``elapsed`` is wall-clock seconds spent here.
    """
    tic = time.time()
    inp = MeasureInput.deserialize(inp_serialized)
    task = inp.task

    error_no = MeasureErrorNo.NO_ERROR
    error_msg = None
    args = []

    # Step 1: replay the state's schedule steps onto the compute DAG.
    try:
        sch, args = task.compute_dag.apply_steps_from_state(
            inp.state, layout_rewrite=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
        )
    # pylint: disable=broad-except
    except Exception:
        error_no = MeasureErrorNo.INSTANTIATION_ERROR
        error_msg = make_traceback_info()

    # Step 2: build and export the library, only if instantiation succeeded.
    # Compare against the named constant rather than a magic 0.
    if error_no == MeasureErrorNo.NO_ERROR:
        # The directory is not removed here: the returned filename points into it.
        dirname = tempfile.mkdtemp()
        filename = os.path.join(dirname, "tmp_func." + build_func.output_format)

        try:
            with transform.PassContext():
                func = build_module.build(
                    sch, args, target=task.target, target_host=task.target_host
                )
            func.export_library(filename, build_func)
        # pylint: disable=broad-except
        except Exception:
            error_no = MeasureErrorNo.COMPILE_HOST
            error_msg = make_traceback_info()
    else:
        filename = ""

    if verbose >= 1:
        # end="" emits no newline, so without flush=True the marker can stay in
        # stdout's buffer (e.g. when stdout is not a TTY) instead of showing progress.
        if error_no == MeasureErrorNo.NO_ERROR:
            print(".", end="", flush=True)
        else:
            print(".E", end="", flush=True)  # Build error

    return filename, args, error_no, error_msg, time.time() - tic
def timed_func():
    """Instantiate and compile ``measure_inputs[index]``, timing the work.

    NOTE(review): ``measure_inputs``, ``index``, ``build_func`` and ``verbose``
    are free variables presumably provided by an enclosing scope that is not
    visible here — confirm against the caller.

    Returns
    -------
    tuple
        ``(filename, args, error_no, error_msg, elapsed)`` where ``filename`` is
        the path of the exported library ("" if schedule instantiation failed),
        ``args`` is what ``apply_steps_from_state`` returned (``[]`` if it
        failed), ``error_no`` is a ``MeasureErrorNo`` value, ``error_msg`` is
        None on success, and ``elapsed`` is wall-clock seconds spent here.
    """
    tic = time.time()
    inp = measure_inputs[index]
    task = inp.task

    error_no = MeasureErrorNo.NO_ERROR
    error_msg = None
    args = []

    # Step 1: replay the state's schedule steps onto the compute DAG.
    try:
        sch, args = task.compute_dag.apply_steps_from_state(inp.state, layout_rewrite=True)
    # pylint: disable=broad-except
    except Exception:
        error_no = MeasureErrorNo.INSTANTIATION_ERROR
        error_msg = make_error_msg()

    # Step 2: build and export the library, only if instantiation succeeded.
    # Compare against the named constant rather than a magic 0.
    if error_no == MeasureErrorNo.NO_ERROR:
        # The directory is not removed here: the returned filename points into it.
        dirname = tempfile.mkdtemp()
        filename = os.path.join(dirname, "tmp_func." + build_func.output_format)

        try:
            # TODO(merrymercy): Port the unroll pass.
            with transform.PassContext():
                func = build_module.build(
                    sch, args, target=task.target, target_host=task.target_host
                )
            func.export_library(filename, build_func)
        # pylint: disable=broad-except
        except Exception:
            error_no = MeasureErrorNo.COMPILE_HOST
            error_msg = make_error_msg()
    else:
        filename = ""

    if verbose >= 1:
        # end="" emits no newline, so without flush=True the marker can stay in
        # stdout's buffer (e.g. when stdout is not a TTY) instead of showing progress.
        if error_no == MeasureErrorNo.NO_ERROR:
            print(".", end="", flush=True)
        else:
            print(".E", end="", flush=True)  # Build error

    return filename, args, error_no, error_msg, time.time() - tic