コード例 #1
0
ファイル: frontends.py プロジェクト: chenghanpeng/tvm
    def load(self, path, shape_dict=None, **kwargs):
        """Load a TFLite flatbuffer model and convert it into Relay.

        Parameters
        ----------
        path : str
            Path to the TFLite model file.
        shape_dict : dict, optional
            Input shapes, forwarded to ``relay.frontend.from_tflite``.
        **kwargs
            Extra keyword arguments forwarded to ``relay.frontend.from_tflite``.

        Returns
        -------
        mod, params
            The Relay module and its parameters.

        Raises
        ------
        TVMCException
            If the file cannot be read as a TFLite model, or its schema
            version is not 3.
        """
        model = lazy_import("tflite.Model")

        with open(path, "rb") as tf_graph:
            content = tf_graph.read()

        # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0
        try:
            tflite_model = model.Model.GetRootAsModel(content, 0)
        except AttributeError:
            tflite_model = model.GetRootAsModel(content, 0)

        # Keep the try minimal: only reading the version can reveal a non-TFLite file.
        try:
            version = tflite_model.Version()
        except Exception as err:
            raise TVMCException("input file not tflite") from err
        logger.debug("tflite version %s", version)

        if version != 3:
            raise TVMCException("input file not tflite version 3")

        logger.debug(
            "parse TFLite model and convert into Relay computation graph")
        mod, params = relay.frontend.from_tflite(tflite_model,
                                                 shape_dict=shape_dict,
                                                 **kwargs)
        return mod, params
コード例 #2
0
ファイル: registry.py プロジェクト: chenghanpeng/tvm
def reconstruct_registry_entity(args, registry):
    """Reconstructs an entity from arguments generated from a registry.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments; the selected name is read from the
        attribute named ``registry.flag_registry_name``.
    registry : object
        Registry exposing ``list_registered()`` and ``flag_registry_name``.
        It is called as ``registry(name, options)`` to build the entity.

    Returns
    -------
    The entity built by ``registry``, or None when no name was selected.

    Raises
    ------
    TVMCException
        If the selected name is not registered, or if options were passed
        for a registered entry other than the selected one.
    """
    possible_names = registry.list_registered()
    name = getattr(args, registry.flag_registry_name)
    if name is None:
        return None

    if name not in possible_names:
        raise TVMCException(
            f'{registry.flag_registry_name.title()} "{name}" is not defined')

    reconstructed = {
        possible_name: _reconstruct_registry_options(args, registry,
                                                     possible_name)
        for possible_name in possible_names
    }

    for possible_name in possible_names:
        if possible_name != name and reconstructed[possible_name]:
            first_option = next(iter(reconstructed[possible_name]))
            # Name the registry kind (e.g. executor or runtime) instead of
            # hard-coding "executor", which is wrong for other registries.
            raise TVMCException(
                f"Passed --{registry.flag_registry_name}-{possible_name}-{first_option} "
                f"but did not specify {possible_name} {registry.flag_registry_name}")

    return registry(name, reconstructed[name])
コード例 #3
0
ファイル: frontends.py プロジェクト: chenghanpeng/tvm
    def load(self, path, shape_dict=None, **kwargs):
        """Load a PaddlePaddle inference model and convert it into Relay.

        Parameters
        ----------
        path : str
            Path to the ``.pdmodel`` file. A ``.pdiparams`` file sharing the
            same prefix must exist next to it.
        shape_dict : dict, optional
            Input shapes, forwarded to ``relay.frontend.from_paddle``.
        **kwargs
            Extra keyword arguments forwarded to ``relay.frontend.from_paddle``.

        Returns
        -------
        The result of ``relay.frontend.from_paddle``.

        Raises
        ------
        TVMCException
            If the model file is missing, lacks the ``.pdmodel`` suffix, or
            its ``.pdiparams`` companion is missing.
        """
        # pylint: disable=C0415
        import paddle

        paddle.enable_static()
        paddle.disable_signal_handler()

        if not os.path.exists(path):
            raise TVMCException("File {} is not exist.".format(path))
        if not path.endswith(".pdmodel"):
            raise TVMCException(
                "Path of model file should be endwith suffixes '.pdmodel'.")
        # Strip only the trailing ".pdmodel". Splitting on every "." and
        # re-joining with "" would delete dots elsewhere in the path
        # (e.g. "./model.pdmodel" would become "/model").
        prefix = path.strip()[:-len(".pdmodel")]
        params_file_path = prefix + ".pdiparams"
        if not os.path.exists(params_file_path):
            raise TVMCException(
                "File {} is not exist.".format(params_file_path))

        # pylint: disable=E1101
        exe = paddle.static.Executor(paddle.CPUPlace())
        prog, _, _ = paddle.static.load_inference_model(prefix, exe)

        return relay.frontend.from_paddle(prog,
                                          shape_dict=shape_dict,
                                          **kwargs)
コード例 #4
0
ファイル: target.py プロジェクト: a1nc/tvm
def validate_targets(parse_targets, additional_target_options=None):
    """
    Apply a series of validations in the targets provided via CLI.

    Parameters
    ----------
    parse_targets : list of dict
        Parsed targets; each entry has at least a ``"name"`` key.
    additional_target_options : dict, optional
        Mapping from a target name to the dict of extra options passed for it
        via ``--target-<name>-<option>``.

    Raises
    ------
    TVMCException
        If a target is duplicated, there is no last target or it is not a TVM
        target, more than two TVM targets are given, or options are provided
        for a target that was not specified.
    """
    tvm_target_kinds = tvm.target.Target.list_kinds()
    targets = [t["name"] for t in parse_targets]

    if len(targets) > len(set(targets)):
        raise TVMCException("Duplicate target definitions are not allowed")

    # An empty list has no last target; report it like a non-TVM last target
    # instead of failing with an IndexError.
    if not targets or targets[-1] not in tvm_target_kinds:
        tvm_target_names = ", ".join(tvm_target_kinds)
        raise TVMCException(
            f"The last target needs to be a TVM target. Choices: {tvm_target_names}"
        )

    # Query the valid kinds once rather than once per target.
    valid_target_kinds = _valid_target_kinds()
    tvm_targets = [t for t in targets if t in valid_target_kinds]
    if len(tvm_targets) > 2:
        verbose_tvm_targets = ", ".join(tvm_targets)
        raise TVMCException(
            "Only two of the following targets can be used at a time. "
            f"Found: {verbose_tvm_targets}.")

    if additional_target_options is not None:
        specified_names = set(targets)
        for target_name, options in additional_target_options.items():
            if target_name not in specified_names:
                first_option = list(options.keys())[0]
                raise TVMCException(
                    f"Passed --target-{target_name}-{first_option}"
                    f" but did not specify {target_name} target")
コード例 #5
0
ファイル: pass_config.py プロジェクト: wang910/tvm
def parse_configs(input_configs):
    """Parse configuration values set via command line.

    Parameters
    ----------
    input_configs: list of str
        list of configurations provided via command line, each in the form
        ``<config>=<value>``. Only the first "=" separates the name from the
        value, so values may themselves contain "=".

    Returns
    -------
    pass_context_configs: dict
        a dict containing key-value configs to be used in the PassContext.
        Empty when ``input_configs`` is empty or None.

    Raises
    ------
    TVMCException
        If an entry is malformed, names a configuration unknown to TVM, or
        names a configuration whose data type TVMC does not support.
    """
    if not input_configs:
        return {}

    all_configs = tvm.ir.transform.PassContext.list_configs()
    supported_config_types = ("IntImm", "runtime.String")
    supported_configs = [
        name for name in all_configs.keys()
        if all_configs[name]["type"] in supported_config_types
    ]

    pass_context_configs = {}

    for config in input_configs:
        if not config:
            raise TVMCException(
                f"Invalid format for configuration '{config}', use <config>=<value>"
            )

        # Each config is expected to be provided as "name=value". Split on the
        # first "=" only, so values that contain "=" are kept intact.
        try:
            name, value = config.split("=", 1)
        except ValueError as err:
            raise TVMCException(
                f"Invalid format for configuration '{config}', use <config>=<value>"
            ) from err
        name = name.strip()
        value = value.strip()

        if name not in all_configs:
            raise TVMCException(
                f"Configuration '{name}' is not defined in TVM. "
                f"These are the existing configurations: {', '.join(all_configs)}"
            )

        if name not in supported_configs:
            raise TVMCException(
                f"Configuration '{name}' uses a data type not supported by TVMC. "
                f"The following configurations are supported: {', '.join(supported_configs)}"
            )

        parsed_value = get_pass_config_value(name, value,
                                             all_configs[name]["type"])
        pass_context_configs[name] = parsed_value

    return pass_context_configs
コード例 #6
0
ファイル: model.py プロジェクト: lfengad/incubator-tvm
    def export_package(
        self,
        executor_factory: Union[GraphExecutorFactoryModule, Executable],
        package_path: Optional[str] = None,
        cross: Optional[Union[str, Callable]] = None,
        cross_options: Optional[str] = None,
        output_format: str = "so",
    ):
        """Save this TVMCModel to file.

        Parameters
        ----------
        executor_factory : GraphExecutorFactoryModule or Executable
            The factory containing the compiled artifacts needed to run this
            model, or a VM ``Executable`` (exported via ``export_vm_format``).
        package_path : str, None
            Where the model should be saved. Note that it will be packaged as a .tar file.
            If not provided, the package will be saved to a generically named file in tmp.
        cross : str or callable object, optional
            Function that performs the actual compilation.
        cross_options : str, optional
            Command line options to be passed to the cross compiler.
        output_format : str
            How to save the modules function library. Must be one of "so" and "tar" to save
            using the classic format or "mlf" to save using the Model Library Format.

        Returns
        -------
        package_path : str
            The path that the package was saved to.

        Raises
        ------
        TVMCException
            If ``output_format`` is unsupported, if MLF output is combined with
            a cross compiler, or if MLF export is requested while
            ``export_model_library_format`` is unavailable.
        """
        if output_format not in ["so", "tar", "mlf"]:
            raise TVMCException(
                "Only 'so', 'tar', and 'mlf' output formats are supported.")

        if output_format == "mlf" and cross:
            raise TVMCException(
                "Specifying the MLF output and a cross compiler is not supported."
            )

        if isinstance(executor_factory, Executable):
            package_path = self.export_vm_format(executor_factory,
                                                 package_path, output_format)
        elif output_format in ["so", "tar"]:
            package_path = self.export_classic_format(executor_factory,
                                                      package_path, cross,
                                                      cross_options,
                                                      output_format)
        elif output_format == "mlf":
            # export_model_library_format is falsy when microTVM support is
            # not available in this build.
            if export_model_library_format:
                package_path = export_model_library_format(
                    executor_factory, package_path)
            else:
                # Raise the driver's own exception type, like every other
                # validation error in this method.
                raise TVMCException(
                    "micro tvm is not enabled. Set USE_MICRO to ON in config.cmake"
                )

        return package_path
コード例 #7
0
ファイル: workspace_pools.py プロジェクト: chenghanpeng/tvm
def _parse_target_string(attr_str, targets, pool_name):
    if attr_str is None:
        raise TVMCException(f'No target specified for Workspace Pool "{pool_name}"')

    target_name = [re.split(",", attr_str)]
    matched_targets = [
        target
        for target in targets
        if any(target.kind.name in target_string_match for target_string_match in target_name[0])
    ]
    if not matched_targets:
        raise TVMCException(f'Workspace Pool "{pool_name}" using undefined Target "{target_name}"')
    return matched_targets
コード例 #8
0
ファイル: frontends.py プロジェクト: chenghanpeng/tvm
        def _validate_text(text):
            """Check the provided file contents.
            The relay.txt artifact contained in the MLF is missing the version header and
            the metadata which is required to use meta[relay.Constant]."""

            if re.compile(r".*\#\[version\.*").match(text) is None:
                raise TVMCException(
                    "The relay model does not include the required version information."
                )
            if re.compile(r".*meta\[.+\].*", re.DOTALL).match(text):
                if "#[metadata]" not in text:
                    raise TVMCException(
                        "The relay model does not include the required #[metadata] section. "
                        "Use ir_mod.astext(show_meta_data=True) to export compatible code."
                    )
コード例 #9
0
ファイル: frontends.py プロジェクト: chenghanpeng/tvm
def guess_frontend(path: str):
    """
    Pick a frontend for a model file by looking at its file extension.

    Parameters
    ----------
    path : str
        The path to the model file.

    Returns
    -------
    frontend : tvm.driver.tvmc.Frontend
        A new instance of the first frontend in ALL_FRONTENDS whose
        ``suffixes()`` contain the lower-cased extension of `path`.

    Raises
    ------
    TVMCException
        When no frontend claims the extension.
    """
    extension = Path(path).suffix.lower()
    # Path.suffix keeps the leading dot; frontends list bare extensions.
    extension = extension[1:] if extension.startswith(".") else extension

    candidates = (fe for fe in ALL_FRONTENDS if extension in fe.suffixes())
    frontend_cls = next(candidates, None)
    if frontend_cls is None:
        raise TVMCException(
            "failed to infer the model format. Please specify --model-format")
    return frontend_cls()
コード例 #10
0
def load_function(full_name):
    """Dynamically load a function by its full name.

    Parameters
    ----------
    full_name: str
        The name of a PackedFunc or a string of the form "path.to.module.func"
        that indicates the module that can be imported.
        The lookup order matters: the name is first looked up as a TVM global
        function; if that is not found, the module part is imported with
        "importlib.import_module" and the function is taken from it.

    Returns
    -------
    func: function or PackedFunc
        The loaded function.

    Raises
    ------
    TVMCException
        If the name is not a global function and has no module part, or the
        imported module has no attribute with the given name.
    """
    global_func = tvm.get_global_func(full_name, allow_missing=True)
    if global_func is not None:
        return global_func

    # Without a "." there is no module to import; report that clearly instead
    # of failing on the tuple unpacking below with a bare ValueError.
    if "." not in full_name:
        raise TVMCException(
            f"No global function '{full_name}' found, and the name is not of "
            "the form 'path.to.module.func'.")

    # split full name "path.to.module.func" into two parts ["path.to.module", "func"]
    module_name, func_name = full_name.rsplit(".", 1)

    # import module and find the function
    module = importlib.import_module(module_name)
    if hasattr(module, func_name):
        return getattr(module, func_name)

    raise TVMCException(
        f"No function '{func_name}' found in module '{module_name}'.")
コード例 #11
0
def read_and_convert_json_into_dict(config_args):
    """Read json configuration file and return a dictionary with all parameters

    Parameters
    ----------
    config_args: argparse.Namespace
        Arguments from command line parser holding the json file path in
        ``config_args.config``. When that value does not contain ".json",
        the suffix is appended and ``config_args.config`` is updated in place.
        If the result is not an existing file, the name is looked up in the
        directory returned by ``get_configs_json_dir()``.

    Returns
    -------
    dictionary
        dictionary with all the json arguments keys and values

    Raises
    ------
    TVMCException
        If the configuration file cannot be found.
    """
    # Bound up front so the error message can always reference it, even when
    # the lookup fails before a configs directory has been resolved.
    config_dir = None
    try:
        if ".json" not in config_args.config:
            config_args.config = config_args.config.strip() + ".json"
        if os.path.isfile(config_args.config):
            json_config_file = config_args.config
        else:
            config_dir = get_configs_json_dir()
            json_config_file = find_json_file(config_args.config, config_dir)
        # Close the file handle once the JSON has been parsed.
        with open(json_config_file, "rb") as json_file:
            return json.load(json_file)

    except FileNotFoundError as err:
        raise TVMCException(
            f"File {config_args.config} does not exist at {config_dir} or is wrong format."
        ) from err
コード例 #12
0
    def __init__(self, package_path: str, project_dir: Optional[Union[Path, str]] = None):
        """Open an exported package, optionally attaching a project directory.

        ``project_dir`` may only be set when the imported package turns out
        to be in Model Library Format; otherwise TVMCException is raised.
        """
        self._tmp_dir = utils.tempdir()
        self.package_path = package_path
        # import_package() determines self.type, which the check below needs.
        self.import_package(self.package_path)

        if project_dir:
            if self.type != "mlf":
                raise TVMCException("Setting 'project_dir' is only allowed when importing a MLF.!")
        self.project_dir = project_dir
コード例 #13
0
    def import_package(self, package_path: str):
        """Load a TVMCPackage from a previously exported TVMCModel.

        The archive is extracted into ``self._tmp_dir``. Afterwards
        ``lib_name``, ``lib_path``, ``type`` ("mlf" or "classic"), ``params``
        and ``graph`` are set on the instance.

        Parameters
        ----------
        package_path : str
            The path to the saved TVMCPackage.

        Raises
        ------
        TVMCException
            If a classic package contains neither ``mod.so`` nor ``mod.tar``.
        """
        temp = self._tmp_dir
        # Use a context manager so the archive handle is closed after extraction.
        with tarfile.open(package_path) as tar:
            if hasattr(tarfile, "data_filter"):
                # When supported, refuse members that would escape the target
                # directory (absolute paths, "..", unsafe links).
                tar.extractall(temp.relpath("."), filter="data")
            else:
                tar.extractall(temp.relpath("."))

        if os.path.exists(temp.relpath("metadata.json")):
            # Model Library Format (MLF)
            self.lib_name = None
            self.lib_path = None
            with open(temp.relpath("metadata.json")) as metadata_json:
                metadata = json.load(metadata_json)

            has_graph_executor = "graph" in metadata["executors"]
            graph = temp.relpath("executor-config/graph/graph.json") if has_graph_executor else None
            params = temp.relpath("parameters/default.params")

            self.type = "mlf"
        else:
            # Classic format
            lib_name_so = "mod.so"
            lib_name_tar = "mod.tar"
            if os.path.exists(temp.relpath(lib_name_so)):
                self.lib_name = lib_name_so
            elif os.path.exists(temp.relpath(lib_name_tar)):
                self.lib_name = lib_name_tar
            else:
                raise TVMCException("Couldn't find exported library in the package.")
            self.lib_path = temp.relpath(self.lib_name)

            graph = temp.relpath("mod.json")
            params = temp.relpath("mod.params")

            self.type = "classic"

        with open(params, "rb") as param_file:
            self.params = bytearray(param_file.read())

        if graph is not None:
            with open(graph) as graph_file:
                self.graph = graph_file.read()
        else:
            self.graph = None
コード例 #14
0
ファイル: model.py プロジェクト: lfengad/incubator-tvm
 def __init__(
     self,
     mod: Optional[tvm.IRModule] = None,
     params: Optional[Dict[str, tvm.nd.NDArray]] = None,
     model_path: Optional[str] = None,
 ):
     """Build a TVMCModel from a module and its params, or from a saved model.

     Parameters
     ----------
     mod : tvm.IRModule, optional
         The Relay module; required together with ``params`` unless
         ``model_path`` is given.
     params : dict of str to tvm.nd.NDArray, optional
         The module parameters; a falsy value is stored as an empty dict.
     model_path : str, optional
         Path to a previously saved TVMCModel. When given, it is loaded via
         ``self.load`` and ``mod``/``params`` are not assigned.

     Raises
     ------
     TVMCException
         If neither both ``mod`` and ``params`` nor ``model_path`` are given.
     """
     has_mod_and_params = mod is not None and params is not None
     if not has_mod_and_params and model_path is None:
         raise TVMCException("Either mod and params must be provided "
                             "or a path to a previously saved TVMCModel")
     self._tmp_dir = utils.tempdir()
     if model_path is None:
         self.mod = mod
         self.params = params or {}
     else:
         self.load(model_path)
コード例 #15
0
def get_codegen_by_target(name):
    """Look up the codegen entry registered for a composite target.

    Parameters
    ----------
    name : str
        The name of the target for which the codegen info should be retrieved.

    Returns
    -------
    dict
        requested target codegen information

    Raises
    ------
    TVMCException
        If nothing is registered in REGISTERED_CODEGEN under `name`.
    """
    try:
        codegen_entry = REGISTERED_CODEGEN[name]
    except KeyError:
        raise TVMCException("Composite target %s is not defined in TVMC." %
                            name)
    return codegen_entry
コード例 #16
0
ファイル: frontends.py プロジェクト: chenghanpeng/tvm
    def load(self, path, shape_dict=None, **kwargs):
        """Load a TorchScript model from `path` and convert it into Relay.

        Parameters
        ----------
        path : str
            Path to a file loadable with ``torch.jit.load``.
        shape_dict : dict
            Mapping of input name to shape; mandatory for this frontend.
        **kwargs
            Extra keyword arguments forwarded to ``relay.frontend.from_pytorch``.

        Raises
        ------
        TVMCException
            If ``shape_dict`` is None.
        """
        torch = lazy_import("torch")

        if shape_dict is None:
            raise TVMCException("--input-shapes must be specified for %s" % self.name())

        scripted = torch.jit.load(path)
        scripted.eval()  # Switch to inference mode

        # The PyTorch frontend expects a list of (name, shape) pairs, not a dict.
        input_infos = [(input_name, shape) for input_name, shape in shape_dict.items()]

        logger.debug("parse Torch model and convert into Relay computation graph")
        return relay.frontend.from_pytorch(
            scripted,
            input_infos,
            keep_quantized_weight=True,
            **kwargs,
        )
コード例 #17
0
ファイル: workspace_pools.py プロジェクト: chenghanpeng/tvm
def _parse_target_attributes_of_pool_name(attr_str, targets):
    if not targets or attr_str is None:
        return {}

    target_attributes = {}
    for pool_values in attr_str:
        pool_name, target_name, target_value = re.split(":", pool_values)
        if pool_name not in target_attributes:
            target_attributes[pool_name] = {}

        matched_targets = [target for target in targets if target_name == target.kind.name]
        if matched_targets:
            target_attributes[pool_name][matched_targets[0]] = target_value
        else:
            raise TVMCException(
                "The workspace pool target specification "
                "needs to contain a subset of the same TVM "
                "targets as when specifying targets to use."
            )
    return target_attributes
コード例 #18
0
ファイル: pass_config.py プロジェクト: wang910/tvm
def get_pass_config_value(name, value, config_type):
    """Get a PassContext configuration value, based on its config data type.

    Parameters
    ----------
    name: str
        config identifier name.
    value: str
        value assigned to the config, provided via command line.
    config_type: str
        data type defined to the config, as string.

    Returns
    -------
    parsed_value: bool, int, str or None
        a representation of the input value, converted to the type
        specified by config_type. None when config_type is neither
        "IntImm" nor "runtime.String".

    Raises
    ------
    TVMCException
        If an "IntImm" value is neither a run of digits nor true/false
        (case insensitive).
    """
    # Default for unhandled types; without it an unknown config_type would
    # end in an UnboundLocalError at the return below.
    parsed_value = None

    if config_type == "IntImm":
        # "Bool" configurations in the PassContext are recognized as
        # IntImm, so deal with this case here
        mapping_values = {
            "false": False,
            "true": True,
        }

        if value.isdigit():
            parsed_value = int(value)
        else:
            # if not an int, accept only values on the mapping table, case insensitive
            parsed_value = mapping_values.get(value.lower(), None)

        if parsed_value is None:
            raise TVMCException(
                f"Invalid value '{value}' for configuration '{name}'. ")

    elif config_type == "runtime.String":
        parsed_value = value

    return parsed_value
コード例 #19
0
ファイル: frontends.py プロジェクト: chenghanpeng/tvm
    def load(self, path, shape_dict=None, **kwargs):
        """Load a Keras model and convert it into Relay.

        Parameters
        ----------
        path : str
            Path passed to ``keras.models.load_model``.
        shape_dict : dict, optional
            Input shapes that override the ones derived from the model.
        **kwargs
            Extra keyword arguments forwarded to ``relay.frontend.from_keras``;
            ``layout`` defaults to "NHWC".

        Raises
        ------
        TVMCException
            If Keras fails to load the model with a ValueError.
        """
        # pylint: disable=C0103
        tf = lazy_import("tensorflow")
        keras = lazy_import("keras", from_pkg_name="tensorflow")

        # tvm build currently imports keras directly instead of tensorflow.keras
        try:
            model = keras.models.load_model(path)
        except ValueError as err:
            raise TVMCException(str(err)) from err

        # There are two flavours of keras model, sequential and
        # functional, TVM expects a functional model, so convert
        # if required:
        if self.is_sequential_p(model):
            model = self.sequential_to_functional(model)

        # Unknown (None) dimensions are replaced by 1. In graph mode the shape
        # entries expose their size through `.value`.
        eager = tf.executing_eagerly()
        in_shapes = []
        for layer in model._input_layers:
            if eager:
                dims = (dim if dim is not None else 1 for dim in layer.input.shape)
            else:
                dims = (dim.value if dim.value is not None else 1
                        for dim in layer.input.shape)
            in_shapes.append(tuple(dims))

        # Only the shapes are needed, so no placeholder input arrays are
        # allocated (and the global NumPy RNG is left untouched).
        input_shapes = {
            name: tuple(int(dim) for dim in shape)
            for name, shape in zip(model.input_names, in_shapes)
        }
        if shape_dict is not None:
            input_shapes.update(shape_dict)
        kwargs.setdefault("layout", "NHWC")
        return relay.frontend.from_keras(model, input_shapes, **kwargs)
コード例 #20
0
ファイル: composite_target.py プロジェクト: junrushao1994/tvm
def get_codegen_by_target(name):
    """Look up the codegen entry registered for a composite target.

    Parameters
    ----------
    name : str
        The name of the target for which the codegen info should be retrieved.

    Returns
    -------
    dict
        requested target codegen information

    Raises
    ------
    TVMCException
        If nothing is registered in REGISTERED_CODEGEN under `name`.

    Warns
    -----
    DeprecationWarning
        When the deprecated "ethos-n78" name is requested.
    """
    if name == "ethos-n78":
        warnings.warn(
            "Please use 'ethos-n' instead of the deprecated 'ethos-n78' target, "
            "which will be removed in a later release of TVM.",
            DeprecationWarning,
        )

    try:
        codegen_entry = REGISTERED_CODEGEN[name]
    except KeyError:
        raise TVMCException("Composite target %s is not defined in TVMC." % name)
    return codegen_entry
コード例 #21
0
def convert_graph_layout(mod, desired_layout):
    """Alter the layout of the input graph.

    Parameters
    ----------
    mod : tvm.IRModule
        The relay module to convert.
    desired_layout : str
        The layout to convert to.

    Returns
    -------
    mod : tvm.IRModule
        The converted module.

    Raises
    ------
    TVMCException
        If running the conversion passes fails; the underlying error is
        chained as the cause.
    """

    # Assume for the time being that graphs only have
    # conv2d as heavily-sensitive operators.
    desired_layouts = {
        "nn.conv2d": [desired_layout, "default"],
        "nn.conv2d_transpose": [desired_layout, "default"],
        "qnn.conv2d": [desired_layout, "default"],
    }

    # Drop unused functions first, then convert the layout of the graph
    # where possible.
    seq = transform.Sequential([
        relay.transform.RemoveUnusedFunctions(),
        relay.transform.ConvertLayout(desired_layouts),
    ])

    with transform.PassContext(opt_level=3):
        try:
            return seq(mod)
        except Exception as err:
            raise TVMCException("Error converting layout to {0}: {1}".format(
                desired_layout, str(err))) from err
コード例 #22
0
ファイル: frontends.py プロジェクト: chenghanpeng/tvm
def get_frontend_by_name(name: str):
    """
    This function will try to get a frontend instance, based
    on the name provided.

    Parameters
    ----------
    name : str
        the name of a given frontend

    Returns
    -------
    frontend : tvm.driver.tvmc.Frontend
        A new instance of the first frontend in ALL_FRONTENDS whose
        ``name()`` equals `name`.

    Raises
    ------
    TVMCException
        If no frontend has the given name; the message lists the choices.
    """

    for frontend in ALL_FRONTENDS:
        if name == frontend.name():
            return frontend()

    raise TVMCException("unrecognized frontend '{0}'. Choose from: {1}".format(
        name, get_frontend_names()))
コード例 #23
0
ファイル: arguments.py プロジェクト: zotanika/incubator-tvm
 def exit(self, status=0, message=None):
     """Raise TVMCException instead of terminating the process.

     `status` and `message` are accepted but ignored; the signature presumably
     mirrors argparse.ArgumentParser.exit (confirm against the class).

     This lets TVMC recover from errors such as "error: invalid choice" while
     pre-parsing the command line. Not every parsing error is caught; for
     instance, errors due to missing required arguments are not.
     """
     raise TVMCException()
コード例 #24
0
    def import_package(self, package_path: str):
        """Load a TVMCPackage from a previously exported TVMCModel.

        The archive is extracted into ``self._tmp_dir``. Afterwards
        ``lib_name``, ``lib_path``, ``type`` ("mlf", "classic" or "vm"),
        ``params`` and ``graph`` are set on the instance. ``params`` and
        ``graph`` are None when the package format does not provide them.

        Parameters
        ----------
        package_path : str
            The path to the saved TVMCPackage.

        Raises
        ------
        TVMCException
            If an MLF package does not contain exactly one module, or a non-MLF
            package contains no recognised library file.
        """
        temp = self._tmp_dir
        # Use a context manager so the archive handle is closed after extraction.
        with tarfile.open(package_path) as tar:
            if hasattr(tarfile, "data_filter"):
                # When supported, refuse members that would escape the target
                # directory (absolute paths, "..", unsafe links).
                tar.extractall(temp.relpath("."), filter="data")
            else:
                tar.extractall(temp.relpath("."))

        if os.path.exists(temp.relpath("metadata.json")):
            # Model Library Format (MLF)
            self.lib_name = None
            self.lib_path = None
            with open(temp.relpath("metadata.json")) as metadata_json:
                metadata = json.load(metadata_json)

            all_module_names = list(metadata["modules"])
            # Raise instead of assert: assertions are stripped under `python -O`.
            if len(all_module_names) != 1:
                raise TVMCException("Multiple modules in MLF is not supported.")

            module_name = all_module_names[0]
            module_metadata = metadata["modules"][module_name]
            has_graph_executor = "graph" in module_metadata["executors"]
            graph = (temp.relpath(f"executor-config/graph/{module_name}.graph")
                     if has_graph_executor else None)
            params = temp.relpath(f"parameters/{module_name}.params")

            self.type = "mlf"
        else:
            # Classic format
            classic_lib_name_so = "mod.so"
            classic_lib_name_tar = "mod.tar"

            # VM format
            vm_lib_name_so = "lib.so"
            vm_lib_name_tar = "lib.tar"

            if os.path.exists(temp.relpath(classic_lib_name_so)):
                self.lib_name = classic_lib_name_so
                self.type = "classic"
            elif os.path.exists(temp.relpath(classic_lib_name_tar)):
                self.lib_name = classic_lib_name_tar
                self.type = "classic"
            elif os.path.exists(temp.relpath(vm_lib_name_so)):
                self.lib_name = vm_lib_name_so
                self.type = "vm"
            elif os.path.exists(temp.relpath(vm_lib_name_tar)):
                self.lib_name = vm_lib_name_tar
                self.type = "vm"
            else:
                raise TVMCException(
                    "Couldn't find exported library in the package.")

            self.lib_path = temp.relpath(self.lib_name)

            # Only the classic format stores a separate graph and params file.
            graph, params = None, None
            if self.type == "classic":
                graph = temp.relpath("mod.json")
                params = temp.relpath("mod.params")

        if params is not None:
            with open(params, "rb") as param_file:
                self.params = bytearray(param_file.read())
        else:
            self.params = None

        if graph is not None:
            with open(graph) as graph_file:
                self.graph = graph_file.read()
        else:
            self.graph = None
コード例 #25
0
def get_pass_config_value(name, value, config_type):
    """Get a PassContext configuration value, based on its config data type.

    Parameters
    ----------
    name: str
        config identifier name.
    value: str
        value assigned to the config, provided via command line.
    config_type: str
        data type defined to the config, as string.

    Returns
    -------
    parsed_value: bool, int, str, list or None
        a representation of the input value, converted to the type
        specified by config_type. For "tir.add_lower_pass" (an "Array"
        config) this is a list of ``(opt_level, pass)`` tuples. None for
        config types not handled here.

    Raises
    ------
    TVMCException
        If an "IntImm" value is invalid, an "Array" config other than
        "tir.add_lower_pass" is given, or the pass list is malformed.
    """

    parsed_value = None

    if config_type == "IntImm":
        # "Bool" configurations in the PassContext are recognized as
        # IntImm, so deal with this case here
        mapping_values = {
            "false": False,
            "true": True,
        }

        if value.isdigit():
            parsed_value = int(value)
        else:
            # if not an int, accept only values on the mapping table, case insensitive
            parsed_value = mapping_values.get(value.lower(), None)

        if parsed_value is None:
            raise TVMCException(
                f"Invalid value '{value}' for configuration '{name}'.")

    elif config_type == "runtime.String":
        parsed_value = value

    elif config_type == "Array":
        if name == "tir.add_lower_pass":
            pass_list = value.split(",")
            if len(pass_list) % 2 != 0:
                raise TVMCException(
                    f"The configuration of '{name}' must be of the form "
                    "'tir.add_lower_pass=opt_level1,pass1,opt_level2,pass2'")

            parsed_value = []
            # Entries alternate: opt_level, pass name, opt_level, pass name, ...
            for raw_level, raw_pass in zip(pass_list[::2], pass_list[1::2]):
                level, pass_name = raw_level.strip(), raw_pass.strip()
                try:
                    level = int(level)
                except ValueError as err:
                    raise TVMCException(
                        f"Only integer is allow for configuration '{name}'.") from err

                # TODO (@leeexyz) We should parse configurations of each tir Pass.
                #     For now, we only use the defaults. Currently, There are four config nodes:
                #     `tir.transform.LoopPartitionConfig`
                #     `tir.transform.UnrollLoopConfig`
                #     `tir.transform.HoistIfThenElseConfig`
                #     `tir.transform.InjectDoubleBufferConfig`
                # loading pass func and calling it to get the Pass
                pass_func = load_function(pass_name)()
                parsed_value.append((level, pass_func))
        else:
            raise TVMCException(
                f"Unsupported configuration '{name}' for '{config_type}' type."
            )

    # Other config types are not raised on here: callers already validated
    # the type before calling this function, so None is returned.
    return parsed_value
コード例 #26
0
ファイル: target.py プロジェクト: a1nc/tvm
def target_from_cli(target, additional_target_options=None):
    """
    Create a tvm.target.Target instance from a
    command line interface (CLI) string.

    Parameters
    ----------
    target : str
        compilation target as plain string,
        inline JSON or path to a JSON file

    additional_target_options: Optional[Dict[str, Dict[str,str]]]
        dictionary of additional target options to be
        combined with parsed targets

    Returns
    -------
    tvm.target.Target
        an instance of target device information
    extra_targets : list of dict
        This list preserves the order in which extra targets were
        provided via command line. Each Dict contains three keys:
        'name', containing the name of the codegen; 'opts' containing
        a key-value for all options passed via CLI; 'raw',
        containing the plain string for this codegen

    Raises
    ------
    TVMCException
        If a plain-text target string cannot be parsed or fails validation.
    """
    extra_targets = []
    # Only plain-text targets can name a separate host. Binding it here keeps
    # the file-path and inline-JSON branches from hitting an UnboundLocalError
    # at the return statement.
    target_host = None

    if os.path.isfile(target):
        with open(target) as target_file:
            logger.debug("target input is a path: %s", target)
            target = target_file.read()
    elif is_inline_json(target):
        logger.debug("target input is inline JSON: %s", target)
    else:
        logger.debug("target input is plain text: %s", target)
        try:
            parsed_targets = parse_target(target)
        except ValueError as error:
            raise TVMCException(
                f"Error parsing target string '{target}'.\nThe error was: {error}"
            ) from error

        validate_targets(parsed_targets, additional_target_options)
        tvm_targets = [
            _combine_target_options(t, additional_target_options)
            for t in parsed_targets if t["is_tvm_target"]
        ]

        # Validated target strings have 1 or 2 tvm targets, otherwise
        # `validate_targets` above will fail.
        target = _recombobulate_target(tvm_targets[0])
        if len(tvm_targets) != 1:
            assert len(tvm_targets) == 2
            target_host = _recombobulate_target(tvm_targets[1])

        extra_targets = [t for t in parsed_targets if not t["is_tvm_target"]]

    return tvm.target.Target(target, host=target_host), extra_targets