def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
    """Common part for building a configuration.

    Parameters
    ----------
    measure_input
        Unpacked as ``(target, task, config)``; ``measure_input.target`` is
        also inspected to detect the ``vta`` device.
    check_gpu : dict, optional
        Keyword arguments forwarded to ``gpu_verify_pass``. When truthy, a
        verify pass is added at phase 2 of ``add_lower_pass``.
    cuda_arch : optional
        When truthy, passed to ``set_cuda_target_arch`` before building.
    build_option : dict, optional
        Keyword arguments for ``build_config``. The dict is copied and is
        never modified by this function.

    Returns
    -------
    tuple
        ``(func, arg_info)`` where ``arg_info`` is a tuple of
        ``(shape, dtype)`` pairs, one per argument of the schedule.

    Raises
    ------
    InstantiationError
        If ``config.valid()`` is false after instantiating the template.
    """
    target, task, config = measure_input
    with target:
        s, args = task.instantiate(config)

        # check invalidity of template and code hash consistency
        if not config.valid():
            raise InstantiationError(config.errors)

        # Work on a copy so the caller's ``build_option`` is left untouched.
        opts = dict(build_option) if build_option else {}
        if check_gpu:
            # Add verify pass to filter out invalid configs in advance.
            # Append rather than overwrite so caller-supplied passes survive.
            lower_passes = list(opts.get("add_lower_pass") or [])
            lower_passes.append((2, gpu_verify_pass(**check_gpu)))
            opts["add_lower_pass"] = lower_passes
        if cuda_arch:
            set_cuda_target_arch(cuda_arch)

        # if target is vta, we need to use vta build
        if (
            hasattr(measure_input.target, "device_name")
            and measure_input.target.device_name == "vta"
        ):
            # pylint: disable=import-outside-toplevel
            import vta

            func = vta.build(s, args, target_host=task.target_host)
        else:
            with build_config(**opts):
                func = build(s, args, target_host=task.target_host)
    return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option=None):
    """Common part for building a configuration.

    Parameters
    ----------
    measure_input
        Unpacked as ``(target, task, config)``; ``measure_input.target`` is
        also inspected to detect the ``vta`` device.
    runtime : optional
        Forwarded unchanged to ``build`` (not used on the ``vta`` path).
    check_gpu : dict, optional
        Keyword arguments forwarded to ``gpu_verify_pass``. When truthy, a
        verify pass is added at phase 2 of ``tir.add_lower_pass``.
    build_option : dict, optional
        Configuration passed as ``config`` to ``tvm.ir.transform.PassContext``.
        The dict is copied and is never modified by this function.

    Returns
    -------
    tuple
        ``(func, arg_info)`` where ``arg_info`` is a tuple of
        ``(shape, dtype)`` pairs, one per argument of the schedule.

    Raises
    ------
    InstantiationError
        If ``config.valid()`` is false after instantiating the template.

    Notes
    -----
    ``task.target_host`` is reassigned with the value returned by
    ``Target.check_and_update_host_consist``.
    """
    target, task, config = measure_input
    target, task.target_host = Target.check_and_update_host_consist(
        target, task.target_host
    )
    with target:
        s, args = task.instantiate(config)

        # check invalidity of template and code hash consistency
        if not config.valid():
            raise InstantiationError(config.errors)

        # Work on a copy so the caller's ``build_option`` is left untouched.
        opts = dict(build_option) if build_option else {}
        if check_gpu:
            # Add verify pass to filter out invalid configs in advance.
            # Append rather than overwrite so caller-supplied passes survive.
            lower_passes = list(opts.get("tir.add_lower_pass") or [])
            lower_passes.append((2, gpu_verify_pass(**check_gpu)))
            opts["tir.add_lower_pass"] = lower_passes

        # if target is vta, we need to use vta build
        if (
            hasattr(measure_input.target, "device_name")
            and measure_input.target.device_name == "vta"
        ):
            # pylint: disable=import-outside-toplevel
            import vta

            func = vta.build(s, args, target_host=task.target_host)
        else:
            with tvm.ir.transform.PassContext(config=opts):
                func = build(s, args, target_host=task.target_host, runtime=runtime)
    return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
def set_task(self, task):
    """Attach ``task`` to this object and prepare it for measurement.

    The task is stored on ``self.task``, then ``check_remote`` is queried
    with the task target and this object's ``key``/``host``/``port``. When
    ``self.check_correctness`` is set, reference inputs and outputs are
    generated with an ``llvm`` build and stored on ``self.ref_input`` and
    ``self.ref_output``.

    Raises
    ------
    RuntimeError
        If ``check_remote`` reports failure.
    """
    self.task = task

    # Bail out early when no device can be obtained from the tracker.
    if not check_remote(task.target, self.key, self.host, self.port):
        raise RuntimeError(
            "Cannot get remote devices from the tracker. "
            "Please check the status of tracker by "
            "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
            "and make sure you have free devices on the queue status."
        )
    logger.info("Get devices for measurement successfully!")

    if not self.check_correctness:
        return

    # Use llvm cpu to generate a reference input/output.
    # This option works for tuning topi, but might not work for your custom op.
    with _target.create("llvm"):
        sched, buffers = task.instantiate(task.config_space.get(0))

    # One random array per argument buffer, matching its shape and dtype.
    reference_inputs = []
    for buf in buffers:
        shape = get_const_tuple(buf.shape)
        reference_inputs.append(np.random.uniform(size=shape).astype(buf.dtype))
    self.ref_input = reference_inputs

    reference_func = build(sched, buffers, "llvm")
    device_arrays = [nd.array(host_array) for host_array in self.ref_input]
    # The arrays are read back after the call; presumably the built function
    # writes its results into them in place -- TODO confirm.
    reference_func(*device_arrays)
    self.ref_output = [arr.asnumpy() for arr in device_arrays]
def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option=None):
    """Shared build routine used when compiling one configuration.

    Parameters
    ----------
    measure_input
        Unpacked as ``(target, task, config)``; ``measure_input.target`` is
        also inspected to detect the ``vta`` device.
    runtime : optional
        Forwarded unchanged to ``build`` (not used on the ``vta`` path).
    check_gpu : dict, optional
        Keyword arguments forwarded to ``gpu_verify_pass``. When truthy, a
        verify pass is appended at phase 2 of ``tir.add_lower_pass``.
    build_option : dict, optional
        Entries merged on top of the current ``PassContext`` config.

    Returns
    -------
    tuple
        ``(func, arg_info)`` where ``arg_info`` is a tuple of
        ``(shape, dtype)`` pairs, one per argument of the schedule.

    Raises
    ------
    InstantiationError
        If ``config.valid()`` is false after instantiating the template.

    Notes
    -----
    ``task.target_host`` is reassigned with the value returned by
    ``Target.canon_target_and_host``.
    """
    target, task, config = measure_input
    target, task.target_host = Target.canon_target_and_host(target, task.target_host)

    with target:
        s, args = task.instantiate(config)

        # Reject invalid templates / inconsistent code hashes up front.
        if not config.valid():
            raise InstantiationError(config.errors)

        # vta targets must go through vta's own build entry point.
        is_vta = (
            hasattr(measure_input.target, "device_name")
            and measure_input.target.device_name == "vta"
        )
        if is_vta:
            # pylint: disable=import-outside-toplevel
            import vta

            func = vta.build(s, args, target_host=task.target_host)
        else:
            # Start from the ambient pass context so its settings carry over.
            ambient_ctx: tvm.ir.transform.PassContext = (
                tvm.ir.transform.PassContext.current()
            )
            merged_config = dict(ambient_ctx.config)
            if build_option is not None:
                merged_config.update(build_option)

            # Copy any existing lowering passes, then add the GPU verifier.
            lower_passes = list(merged_config.get("tir.add_lower_pass", []))
            if check_gpu:
                lower_passes.append((2, gpu_verify_pass(**check_gpu)))
            merged_config["tir.add_lower_pass"] = lower_passes

            build_ctx = tvm.ir.transform.PassContext(
                opt_level=ambient_ctx.opt_level,
                required_pass=ambient_ctx.required_pass,
                disabled_pass=ambient_ctx.disabled_pass,
                instruments=ambient_ctx.instruments,
                config=merged_config,
            )
            with build_ctx:
                func = build(s, args, target_host=task.target_host, runtime=runtime)

    arg_info = tuple((get_const_tuple(x.shape), x.dtype) for x in args)
    return func, arg_info
def _build(lowered_funcs):
    """Call ``build`` on ``lowered_funcs`` with ``target="llvm"`` and return its result."""
    built = build(lowered_funcs, target="llvm")
    return built