def optimize(mod, target=None, params=None):
    """Optimize a Relay module without generating runtime code.

    Parameters
    ----------
    mod : :py:class:`~tvm.IRModule`
        The module to optimize. Passing a relay.Function is deprecated.

    target : None, or any multi-target like object, see Target.canon_multi_target
        For homogeneous compilation, the unique build target.
        For heterogeneous compilation, a dictionary or list of possible
        build targets. Defaults to the current target in the environment
        if None.

    params : dict of str to NDArray
        Input parameters to the graph that do not change during inference
        time. Used for constant folding.

    Returns
    -------
    mod : :py:class:`~tvm.IRModule`
        The optimized relay module.

    params : dict
        The parameters of the final graph.
    """
    if not isinstance(mod, (IRModule, _function.Function)):
        raise ValueError("Type of input parameter mod must be tvm.IRModule")

    if isinstance(mod, _function.Function):
        # Deprecated entry path: promote a bare function to a module,
        # binding params up front so they are folded as constants.
        if params:
            mod = bind_params_by_name(mod, params)
        mod = IRModule.from_expr(mod)
        warnings.warn(
            "Please use input parameter mod (tvm.IRModule) "
            "instead of deprecated parameter func (tvm.relay.function.Function)",
            DeprecationWarning,
        )

    canonical_targets = Target.canon_multi_target_and_host(Target.target_or_current(target))

    # A FallbackContext as the current dispatch context (the default root
    # context) means the caller supplied no tuning records, so pre-tuned
    # parameters are loaded from TopHub; otherwise keep the caller's context.
    on_fallback = isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext)
    tuning_ctx = (
        autotvm.tophub.context(canonical_targets) if on_fallback else autotvm.utils.EmptyContext()
    )

    with tuning_ctx:
        compiler = BuildModule()
        mod, params = compiler.optimize(mod, target=canonical_targets, params=params)

    return mod, params
def build(
    ir_mod,
    target=None,
    target_host=None,
    executor=Executor("graph"),
    runtime=Runtime("cpp"),
    workspace_memory_pools=None,
    params=None,
    mod_name="default",
):
    # fmt: off
    # pylint: disable=line-too-long
    """Helper function that builds a Relay function to run on TVM graph executor.

    Parameters
    ----------
    ir_mod : :py:class:`~tvm.IRModule`
        The IR module to build. Using relay.Function is deprecated.

    target : None, or any multi-target like object, see Target.canon_multi_target
        For homogeneous compilation, the unique build target.
        For heterogeneous compilation, a dictionary or list of possible build targets.
        Defaults to the current target in the environment if None.

    target_host : None, or any target like object, see Target.canon_target
        Host compilation target, if target is device.

    executor : Optional[Executor]
        The executor configuration with which to build the model.
        Defaults to "graph" if no executor specified.

    runtime : Optional[Runtime]
        Runtime configuration to use when building the model.
        Defaults to "cpp" if no runtime specified.

    workspace_memory_pools : Optional[WorkspaceMemoryPools]
        The object that contains an Array of PoolInfo objects
        that hold properties of workspace pools that could be
        used by the inference.

    params : dict of str to NDArray
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

    mod_name: Optional[str]
        The module name we will build

    Returns
    -------
    factory_module : tvm.relay.backend.executor_factory.ExecutorFactoryModule
        The runtime factory for the TVM graph executor.

    Raises
    ------
    ValueError
        If `ir_mod` is neither an IRModule nor a relay Function.
    AssertionError
        If the resolved executor is neither "aot" nor "graph".
    """
    # pylint: enable=line-too-long
    # fmt: on

    # NOTE(review): Executor("graph") / Runtime("cpp") defaults are evaluated
    # once at definition time and shared across calls — assumed to be
    # immutable config objects; confirm if Executor/Runtime gain mutable state.
    if not isinstance(ir_mod, (IRModule, _function.Function)):
        raise ValueError("Type of input parameter mod must be tvm.IRModule")

    if isinstance(ir_mod, _function.Function):
        # Deprecated entry path: wrap a bare function into a module, binding
        # params first so they participate in constant folding.
        if params:
            ir_mod = bind_params_by_name(ir_mod, params)
        ir_mod = IRModule.from_expr(ir_mod)
        warnings.warn(
            "Please use input parameter mod (tvm.IRModule) "
            "instead of deprecated parameter mod (tvm.relay.function.Function)",
            DeprecationWarning,
        )

    raw_targets = Target.canon_multi_target_and_host(Target.target_or_current(target), target_host)
    assert len(raw_targets) > 0
    target_host = raw_targets[0].host

    # All of this logic is to raise deprecation warnings for various parameters
    # TODO(Mousius) Remove these after some time
    deprecated_params_target = target_host or raw_targets[0]
    deprecated_executor, deprecated_runtime = _reconstruct_from_deprecated_options(
        deprecated_params_target
    )
    if deprecated_executor:
        executor = deprecated_executor
    if deprecated_runtime:
        runtime = deprecated_runtime

    # If current dispatch context is fallback context (the default root context),
    # then load pre-tuned parameters from TopHub
    if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
        tophub_context = autotvm.tophub.context(list(raw_targets))
    else:
        tophub_context = autotvm.utils.EmptyContext()

    with tophub_context:
        bld_mod = BuildModule()
        graph_json, runtime_mod, params = bld_mod.build(
            mod=ir_mod,
            target=raw_targets,
            params=params,
            executor=executor,
            runtime=runtime,
            workspace_memory_pools=workspace_memory_pools,
            mod_name=mod_name,
        )
        func_metadata = bld_mod.get_function_metadata()
        devices = bld_mod.get_devices()
        lowered_ir_mods = bld_mod.get_irmodule()
        executor_codegen_metadata = bld_mod.get_executor_codegen_metadata()

        if str(executor) == "aot":
            executor_factory = _executor_factory.AOTExecutorFactoryModule(
                ir_mod,
                lowered_ir_mods,
                raw_targets,
                executor,
                runtime,
                runtime_mod,
                mod_name,
                params,
                func_metadata,
                executor_codegen_metadata,
                devices,
            )
        elif str(executor) == "graph":
            executor_factory = _executor_factory.GraphExecutorFactoryModule(
                ir_mod,
                raw_targets,
                executor,
                graph_json,
                runtime_mod,
                mod_name,
                params,
                func_metadata,
            )
        else:
            # Fix: the original `assert False, "Executor " + executor + ...`
            # raised TypeError (str + Executor concatenation) instead of the
            # intended message, and the whole check vanished under `python -O`,
            # which would then hit a NameError on `executor_factory`. Raise
            # AssertionError explicitly (same exception type) with a rendered
            # message that survives -O.
            raise AssertionError(f"Executor {executor} not supported")

        return executor_factory