예제 #1
0
    def convert(cls, op_name: str, params: Dict, weights: Dict = None):
        """
        Convert third party operation's param into MindSpore operation.

        Looks up the mapper registered for `op_name` in `TABLE`, converts the
        op name, params and weights with it, and renders the code template.

        Args:
            op_name (str): Operation name in ONNX.
            params (dict): Params in onnx.
            weights (dict): Weights in onnx.

        Returns:
            tuple, `(code_template, exchange_msg, outputs_list, outputs_mapping)`
            as produced by the mapper's template generator. When no mapper is
            registered for `op_name`, `(None, dict(), None, dict())` is returned;
            when the mapper module, class or one of its hooks cannot be loaded,
            `(None, None, None, None)` is returned.
        """
        module_name = TABLE.get(op_name)

        if not module_name:
            return None, dict(), None, dict()

        pos = module_name.rfind(".")
        try:
            converter = getattr(importlib.import_module(module_name[:pos]),
                                module_name[pos + 1:])
            op_name_converter = getattr(converter, GET_OP_NAME)
            params_converter = getattr(converter, GET_OP_PARAMS)
            weights_converter = getattr(converter, GET_OP_WEIGHTS)
            template_generator = getattr(converter, GET_OP_TEMPLATE)
        except (ModuleNotFoundError, AttributeError) as e:
            # If mapper can not be found (missing module, mapper class or
            # mapper hook), then skip it.
            err_msg = f"Converting {op_name} failed, see {str(e)}"
            log.error(err_msg)
            return None, None, None, None

        try:
            converter_name = op_name_converter(params=params, weights=weights, op_name=op_name)
            converted_params = params_converter(params=params, weights=weights)

            # Shape entries are dropped before the params reach the template.
            converted_params.pop("input_shape", None)
            converted_params.pop("output_shape", None)
            # set to converted_weights to enable weight migration
            converted_weights = weights_converter(weights=weights) if weights else dict()
            code_template, exchange_msg, outputs_list, outputs_mapping = template_generator(
                operation=converter_name,
                converted_params=converted_params,
                raw_params=params,
                weights=weights,
                trainable_params=converted_weights
            )

        except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:
            err_msg = f"Converting {op_name} failed, see {str(e)}"
            log.error(err_msg)
            # Fall back to a template built from the raw ONNX op name.
            # NOTE(review): this call passes `params=` rather than the
            # `converted_params=`/`raw_params=` keywords used above — confirm
            # the template generator accepts it as intended.
            code_template, exchange_msg, outputs_list, outputs_mapping = template_generator(
                operation=op_name,
                params=params,
                weights=weights
            )

        return code_template, exchange_msg, outputs_list, outputs_mapping
예제 #2
0
    def create(cls, graph):
        """
        Factory method of hierarchical tree.

        Walks `graph.nodes_in_topological_order`, attaches each node's input
        and output shapes, and inserts the node into a new tree. Nodes that
        are `OnnxGraphNode` instances are inserted under the name produced by
        `_tf_model_node_name_reformat`.

        Args:
            graph: Graph obj.

        Returns:
            Union[HierarchicalTree, tuple], the tree alone, or
            `(tree, node_scope_name)` when any node was renamed, where
            `node_scope_name` maps original node names to reformatted names.

        Raises:
            NodeInputMissing: If a node has no input shape.
        """
        tree = HierarchicalTree()
        node_scope_name = dict()
        for node_name in graph.nodes_in_topological_order:
            node_inst = graph.get_node(node_name)
            node_input = graph.get_input_shape(node_name)
            node_output = graph.get_output_shape(node_name)
            if not node_input:
                err_msg = f"This model is not supported now. " \
                          f"Cannot find {node_name}'s input shape."
                error = NodeInputMissing(err_msg)
                log.error(str(error))
                raise error

            tree_node_name = node_name
            if isinstance(node_inst, OnnxGraphNode):
                tree_node_name = _tf_model_node_name_reformat(node_inst, node_name)
                node_scope_name[node_name] = tree_node_name

            node_inst.add_input_and_output_shape(node_input, node_output)
            tree.insert(node_inst, tree_node_name)

        if node_scope_name:
            return tree, node_scope_name
        return tree
예제 #3
0
    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None):
        """
        Verify the PyTorch conversion runtime, then run the wrapped converter.

        The process exits with a non-zero status when torch, onnx or
        onnxruntime cannot be found, or when `lib_version_satisfied` rejects
        the installed onnx/onnxruntime versions against `ONNX_MIN_VER` /
        `ONNXRUNTIME_MIN_VER`.
        """
        def _abort(error):
            # Report to the log file and the console, then stop with a failure
            # status so shells/CI do not treat the run as successful.
            log.error(error)
            log_console.error("\n")
            log_console.error(str(error))
            log_console.error("\n")
            sys.exit(-1)

        # Check whether PyTorch, onnx and onnxruntime are installed.
        if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"):
            _abort(RuntimeIntegrityError(f"PyTorch, onnx(>={ONNX_MIN_VER}) and "
                                         f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) "
                                         f"are required when using graph based "
                                         f"scripts converter, and PyTorch version must "
                                         f"be consisted with model generation runtime."))

        onnx = import_module("onnx")
        ort = import_module("onnxruntime")

        if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
                or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER):
            _abort(RuntimeIntegrityError(
                f"onnx(>={ONNX_MIN_VER}) and "
                f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
                f"based scripts converter for Pytorch conversion."
            ))

        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder)
예제 #4
0
    def initialize(self):
        """
        Initialize the OnnxDataLoader.

        Runs, in order: node parsing, optional shape inference (when
        `_is_infer_shape` is set), tensor parsing, elimination analysis,
        node connection building and fetching of eliminated node values.

        Raises:
            ModuleNotFoundError: If `_check_initialization()` returns a falsy
                value ("Unable to Find ONNX Model").
            Exception: Any error raised during shape inference is logged and
                re-raised.
        """

        # check init conditions met
        if not self._check_initialization():
            err = ModuleNotFoundError("Unable to Find ONNX Model")
            log.error(str(err))
            # Abort instead of parsing without a usable model.
            raise err

        # 1. parse all nodes
        self._parse_nodes()

        # 2. parse value info (incl. node output shape)
        if self._is_infer_shape:
            try:
                self._infer_model()
                self._parse_value_info()
                self._parse_node_output_shape()
            except Exception as e:
                log.error(str(e))
                log.exception(e)
                raise

        # 3. parse all tensors
        self._parse_tensors()

        # 4. Optimize graph to eliminate some nodes.
        self._find_nodes_to_be_eliminated()

        # 5. build nodes connections
        self.build_nodes_connection()

        # 6. Run onnx model to fetch actual value of eliminated nodes.
        self._fetch_eliminated_nodes_value()
    def parse(cls, model_path: str, **kwargs):
        """
        Parse a PyTorch model file into an ONNX model.

        Args:
            model_path (str): Model file path.
            **kwargs: Must contain `input_nodes`, which is forwarded to
                `_convert_pytorch_graph_to_onnx`.

        Returns:
            object, the ONNX model returned by `_convert_pytorch_graph_to_onnx`
            (called with `opset_version=11`).

        Raises:
            FileNotFoundError: If `model_path` does not exist.
            ModuleNotFoundError: If the scripts defining the model cannot be
                imported (see `--project_path`).
        """
        if not os.path.exists(model_path):
            error = FileNotFoundError("`model_path` must be assigned with "
                                      "an existed file path.")
            log.error(str(error))
            raise error

        try:
            return cls._convert_pytorch_graph_to_onnx(
                model_path, kwargs['input_nodes'], opset_version=11)
        except ModuleNotFoundError as e:
            error_msg = "Cannot find model scripts in system path, " \
                        "set `--project_path` to the path of model scripts folder correctly."
            error = ModuleNotFoundError(error_msg)
            log.error(str(error))
            # Keep the original missing-module error as the direct cause.
            raise error from e
예제 #6
0
def main_graph_base_converter(file_config):
    """
    Dispatch a graph-based conversion to the converter for the model's framework.

    Args:
        file_config (dict): Conversion config. Reads `model_file`, `shape`,
            `outfile_dir` and `report_dir`; TensorFlow models additionally
            require non-empty `input_nodes` and `output_nodes`.

    Raises:
        UnknownModel: If the model is neither PyTorch nor TensorFlow.
    """
    model_path = file_config['model_file']
    framework = get_framework_type(model_path)

    if framework == FrameworkType.PYTORCH.value:
        graph_based_converter_pytorch_to_ms(
            graph_path=model_path,
            sample_shape=file_config['shape'],
            output_folder=file_config['outfile_dir'],
            report_folder=file_config['report_dir'])
        return

    if framework != FrameworkType.TENSORFLOW.value:
        unsupported = UnknownModel("Get UNSUPPORTED model.")
        log.error(str(unsupported))
        raise unsupported

    # TensorFlow graphs cannot be converted without explicit node names.
    check_params_exist(['input_nodes', 'output_nodes'], file_config)
    graph_based_converter_tf_to_ms(
        graph_path=model_path,
        sample_shape=file_config['shape'],
        input_nodes=file_config['input_nodes'],
        output_nodes=file_config['output_nodes'],
        output_folder=file_config['outfile_dir'],
        report_folder=file_config['report_dir'])
예제 #7
0
    def _trace_torch_graph(self, input_shape):
        """
        Trace torch computational graph.

        Puts `self.model` into eval mode and traces it with `onnx_tracer`
        on a random sample of `input_shape`, using
        `OperatorExportTypes.ONNX`.

        Args:
            input_shape (tuple): Shape.

        Returns:
            object, pytorch graph.

        Raises:
            RuntimeError: If tracing fails; the error is logged and re-raised.
        """
        import torch
        from torch.onnx import OperatorExportTypes
        from .torch_utils import OverloadTorchModuleTemporarily
        from .torch_utils import create_autograd_variable
        from .torch_utils import onnx_tracer

        batched_sample = create_autograd_variable(torch.rand(*input_shape))

        try:
            # Assign execution mode to eval.
            self.model.eval()

            # NOTE(review): the original comment here was truncated ("In pytorch
            # higher version, trace function has a known ..."); presumably
            # OverloadTorchModuleTemporarily works around a known tracing
            # issue in newer PyTorch versions — confirm.
            with OverloadTorchModuleTemporarily():
                graph = onnx_tracer(self.model, batched_sample,
                                    OperatorExportTypes.ONNX)
            return graph
        except RuntimeError as error:
            log.error(str(error))
            log.exception(error)
            raise
예제 #8
0
def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir,
         report, project_path):
    """
    Run converter command.

    Args:
        in_files (str): The file path or directory to convert.
        model_file(str): The pytorch .pth to convert on graph based schema.
        shape(list): The input tensor shape of module_file.
        input_nodes(str): The input node(s) name of Tensorflow model, split by ','.
        output_nodes(str): The output node(s) name of Tensorflow model, split by ','.
        out_dir (str): The output directory to save converted file.
        report (str): The report file path.
        project_path(str): Pytorch scripts project path; appended to
            `sys.path` (if not already present) before graph conversion.

    Raises:
        FileNotFoundError: If neither `in_files` nor `model_file` is set.
    """
    report_dir = report if report else out_dir

    if in_files:
        if os.path.isfile(in_files):
            root_path = os.path.dirname(in_files)
            source_files = [in_files]
        else:
            root_path = in_files
            # Collect every file under the directory tree.
            source_files = [os.path.join(root_dir, file_name)
                            for root_dir, _, file_names in os.walk(in_files)
                            for file_name in file_names]
        main({
            'root_path': root_path,
            'in_files': source_files,
            'outfile_dir': out_dir,
            'report_dir': report_dir
        })
    elif model_file:
        file_config = {
            'model_file': model_file,
            'shape': shape if shape else [],
            'input_nodes': input_nodes,
            'output_nodes': output_nodes,
            'outfile_dir': out_dir,
            'report_dir': report_dir
        }
        # Make the user's model scripts importable for model loading.
        if project_path and project_path not in sys.path:
            sys.path.append(project_path)

        main_graph_base_converter(file_config)
    else:
        error_msg = "`--in_file` and `--model_file` should be set at least one."
        error = FileNotFoundError(error_msg)
        # No exception is being handled here, so log.exception would have no
        # traceback to attach; log the message only.
        log.error(str(error))
        raise error

    log_console.info("\n")
    log_console.info("MindConverter: conversion is completed.")
    log_console.info("\n")
예제 #9
0
    def parse(cls, model_path: str, **kwargs):
        """
        Parser pytorch graph.

        Loads the whole model object with `torch.load`, mapping storages to
        CPU when CUDA is not available.

        Args:
            model_path (str): Model file path.
            **kwargs: Not used.

        Returns:
            object, torch model.

        Raises:
            FileNotFoundError: If `model_path` does not exist.
            ModuleNotFoundError: If the scripts defining the model cannot be
                imported (see `--project_path`).
        """
        torch = import_module("torch")

        if not os.path.exists(model_path):
            error = FileNotFoundError("`model_path` must be assigned with "
                                      "an existed file path.")
            log.error(str(error))
            raise error

        # SECURITY NOTE(review): torch.load unpickles the file and can execute
        # arbitrary code; only convert model files from trusted sources.
        try:
            if torch.cuda.is_available():
                model = torch.load(f=model_path)
            else:
                model = torch.load(f=model_path, map_location="cpu")
        except ModuleNotFoundError as e:
            error_msg = "Cannot find model scripts in system path, " \
                        "set `--project_path` to the path of model scripts folder correctly."
            error = ModuleNotFoundError(error_msg)
            log.error(str(error))
            # Keep the original missing-module error as the direct cause.
            raise error from e

        return model
예제 #10
0
    def save_source_files(self, out_folder: str, mapper: Mapper,
                          model_name: str,
                          report_folder: str = None,
                          scope_name_map: dict = None) -> None:
        """
        Save source codes to target folder.

        Args:
            out_folder (str): Output folder.
            mapper (Mapper): Mapper of third party framework and mindspore.
            model_name(str): Name of Converted model.
            report_folder (str): Report folder.
            scope_name_map(dict): Scope name map of tensorflow; stored on the
                instance when given.

        Raises:
            NodeInputTypeNotSupportError: Propagated from code generation.
            ScriptGenerationError: Propagated from code generation.
            ReportGenerationError: Propagated from code generation.
        """
        if scope_name_map:
            self._scope_name_map = scope_name_map
        try:
            self._adjust_structure()
            code_fragments = self._generate_codes(mapper)
        except (NodeInputTypeNotSupportError, ScriptGenerationError, ReportGenerationError):
            log.error("Error occur when generating codes.")
            raise

        save_code_file_and_report(model_name, code_fragments, out_folder, report_folder)
예제 #11
0
    def parse(cls, model_path: str, **kwargs):
        """
        Parse a TensorFlow computational graph file (.pb) into ONNX.

        Args:
            model_path (str): Model file path.
            **kwargs: `input_nodes` (presumably a dict; only its keys, the
                input node names, are used) and `output_nodes` (iterable of
                output node names).

        Returns:
            object, ONNX model returned by `convert_tf_graph_to_onnx`.

        Raises:
            FileNotFoundError: If `model_path` does not exist.
            ModelLoadingError: If an input or output node name is invalid.
        """
        onnx_utils = import_module(
            "mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils")
        to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx")
        requested_inputs = kwargs.get('input_nodes')
        requested_outputs = kwargs.get('output_nodes')

        if not os.path.exists(model_path):
            missing = FileNotFoundError("`model_path` must be assigned with "
                                        "an existed file path.")
            log.error(str(missing))
            raise missing

        inputs_csv = ",".join(requested_inputs.keys())
        outputs_csv = ",".join(requested_outputs)
        # Validate both sides before reporting, inputs first.
        bad_names = (
            ("Input", TFGraphParser.invalid_nodes_name(inputs_csv)),
            ("Output", TFGraphParser.invalid_nodes_name(outputs_csv)),
        )
        for side, names in bad_names:
            if names:
                raise ModelLoadingError(f"Invalid {side} Node Name Found: {', '.join(names)}")

        return to_onnx(model_path,
                       model_inputs=inputs_csv,
                       model_outputs=outputs_csv)
    def initialize(self):
        """
        Initialize the OnnxDataLoader.

        Steps, in order: graph-level parsing, tensor parsing, node parsing,
        optional shape inference (only when `_is_infer_shape` is set), search
        for eliminable nodes, node connection building, and fetching the
        values of eliminated nodes. Shape-inference errors are logged and
        re-raised.
        """
        # Graph-level info comes first.
        self._parse_graph()

        # Tensors before nodes: node parsing needs tensor info to handle
        # weight sharing.
        self._parse_tensors()
        self._parse_nodes()

        # Optional shape inference (value info and node output shapes).
        if self._is_infer_shape:
            try:
                self._infer_model()
                self._parse_value_info()
                self._parse_node_output_shape()
            except Exception as exc:
                log.error(str(exc))
                log.exception(exc)
                raise exc

        # Graph optimization, wiring, then run the ONNX model to obtain the
        # actual values of the eliminated nodes.
        self._find_nodes_to_be_eliminated()
        self.build_nodes_connection()
        self._fetch_eliminated_nodes_value()
예제 #13
0
 def load_metadata(**kwargs):
     """
     Load graph metadata.

     Not implemented for `PyTorchGraph`: always logs an error and raises.

     Raises:
         NotImplementedError: Always.
     """
     message = "class `PyTorchGraph` has not implemented `load_metadata()`."
     log.error(message)
     raise NotImplementedError(message)
예제 #14
0
 def _f(arch, mapper):
     """
     Call `func(arch, mapper=mapper)` and return its result.

     Raises:
         cls: Built with `msg` and chained from any exception whose type is
             in `cls.raise_from()`; the original is logged first.
     """
     try:
         return func(arch, mapper=mapper)
     except cls.raise_from() as caught:
         wrapped = cls(msg=msg)
         log.error(msg)
         log.exception(caught)
         raise wrapped from caught
예제 #15
0
 def _f(arch, model_path, **kwargs):
     """
     Call `func(arch, model_path=model_path, **kwargs)` and return its result.

     Raises:
         cls: Built with `msg` and chained from any exception whose type is
             in `cls.raise_from()`; the original is logged first.
     """
     try:
         return func(arch, model_path=model_path, **kwargs)
     except cls.raise_from() as caught:
         wrapped = cls(msg=msg)
         log.error(msg)
         log.exception(caught)
         raise wrapped from caught
예제 #16
0
    def __new__(cls, *args, **kwargs):
        """
        Control the create action of graph.

        The model is taken from the first positional argument or, when no
        positional arguments are given, from the keyword named by
        `cls._REQUIRED_PARAM_OF_MODEL`.

        Raises:
            ValueError: If the model argument is missing or falsy.
        """
        model_param = args[0] if args else kwargs.get(
            cls._REQUIRED_PARAM_OF_MODEL)
        if not model_param:
            error = ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` "
                               f"can not be None.")
            # No exception is being handled here, so log.exception would have
            # no traceback to attach; log the message only.
            log.error(str(error))
            raise error

        return super(BaseGraph, cls).__new__(cls)
예제 #17
0
def check_params_exist(params: list, config):
    """
    Check that every required parameter is set in `config`.

    A parameter counts as missing when it is absent or its value is falsy.

    Args:
        params (list): Names of the required parameters.
        config (dict): Config to check.

    Raises:
        ParamMissError: Built from the comma-separated names of all missing
            parameters, in `params` order.
    """
    missing = [param for param in params if not config.get(param)]
    if missing:
        error = ParamMissError(', '.join(missing))
        log.error(str(error))
        raise error
예제 #18
0
def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir,
         report):
    """
    Run converter command.

    Script conversion (`in_files`) takes precedence over graph-based
    conversion (`model_file`). If neither is given, the error is logged,
    printed to the console, and the process exits with status -1.

    Args:
        in_files (str): The file path or directory to convert.
        model_file (str): The model to convert on graph based schema.
        shape (Sequence[tuple]): The input tensor shape of the model.
        input_nodes (Sequence[str]): The input node(s) name of model.
        output_nodes (Sequence[str]): The output node(s) name of model.
        out_dir (str): The output directory to save converted file.
        report (str): The report file path; defaults to `out_dir`.
    """
    report_dir = report if report else out_dir

    if not in_files and not model_file:
        missing_input = FileNotFoundError(
            "`--in_file` and `--model_file` should be set at least one.")
        log.error(str(missing_input))
        log_console.error(str(missing_input))
        sys.exit(-1)

    if in_files:
        if os.path.isfile(in_files):
            root_path, file_list = os.path.dirname(in_files), [in_files]
        else:
            root_path = in_files
            file_list = [os.path.join(walk_root, name)
                         for walk_root, _, names in os.walk(in_files)
                         for name in names]
        main({
            'root_path': root_path,
            'in_files': file_list,
            'outfile_dir': out_dir,
            'report_dir': report_dir
        })
    else:
        main_graph_base_converter({
            'model_file': model_file,
            'shape': shape if shape else [],
            'input_nodes': input_nodes,
            'output_nodes': output_nodes,
            'outfile_dir': out_dir,
            'report_dir': report_dir
        })
    log_console.info("MindConverter: conversion is completed.")
예제 #19
0
    def convert(cls, op_name: str, params: Dict, weights: Dict = None):
        """
        Convert third party operation's param into MindSpore operation.

        Args:
            op_name (str): Operation name in ONNX.
            params (dict): Params in onnx.
            weights (dict): Weights in onnx.

        Returns:
            tuple, `(converter_name, converted_params, converted_settings,
            converted_weights)`, where `converted_params` also contains the
            converted weights. Returns `(None, dict(), None, dict())` when no
            mapper is registered, the mapper cannot be loaded, or conversion
            fails.
        """
        module_name = TABLE.get(op_name)

        if not module_name:
            return None, dict(), None, dict()

        pos = module_name.rfind(".")
        try:
            converter = getattr(importlib.import_module(module_name[:pos]),
                                module_name[pos + 1:])
            op_name_converter = getattr(converter, GET_OP_NAME)
            params_converter = getattr(converter, GET_OP_PARAMS)
            weights_converter = getattr(converter, GET_OP_WEIGHTS)
            settings_converter = getattr(converter, GET_OP_SETTINGS)
        except (ModuleNotFoundError, AttributeError) as e:
            # If mapper can not be found (missing module, mapper class or
            # mapper hook), then skip it.
            err_msg = f"Converting {op_name} failed, see {str(e)}"
            log.error(err_msg)
            return None, dict(), None, dict()

        try:
            converter_name = op_name_converter(params=params,
                                               weights=weights,
                                               op_name=op_name)
            converted_params = params_converter(params=params, weights=weights)
            converted_weights = weights_converter(
                weights=weights) if weights else dict()
            # Converted weights are also exposed through the params dict.
            converted_params.update(converted_weights)
            converted_settings = settings_converter(params=params,
                                                    weights=weights)
        except (AttributeError, KeyError, ValueError, TypeError,
                IndexError) as e:
            err_msg = f"Converting {op_name} failed, see {str(e)}"
            log.error(err_msg)
            return None, dict(), None, dict()

        return converter_name, converted_params, converted_settings, converted_weights
예제 #20
0
 def _f(*args, **kwargs):
     """
     Run `func`, wrapping exceptions matched by `cls.raise_from()` in `cls`.

     The wrapped error's `root_exception_error_code` is set to the original's
     `error_code()` when the original is a `MindConverterException`, else
     None. Any other exception is logged and re-raised unchanged.
     """
     try:
         output = func(*args, **kwargs)
     except cls.raise_from() as e:
         error = cls(msg=msg)
         error_code = e.error_code() if isinstance(e, MindConverterException) else None
         error.root_exception_error_code = error_code
         log.error(msg)
         log.exception(e)
         # Chain explicitly so the original is reported as the direct cause.
         raise error from e
     except Exception as e:
         log.error(msg)
         log.exception(e)
         raise
     return output
예제 #21
0
def save_intermediate_graph(dataloader, output_folder):
    """
    Save intermediate graph and topological order into output_folder.

    Writes `topological_order.txt` (one row per node, in
    `dataloader.nodes_dict` order, holding the op type padded to 30 columns
    and the node name) and `graph.onnx` (`dataloader.inferred_model`). If
    saving the graph fails, both files are removed before the error is
    re-raised.

    Args:
        dataloader (OnnxDataLoader): Dataloader inst.
        output_folder (str): Output folder path.

    Raises:
        FileExistsError: If either output file already exists.
    """
    placeholder_width = 30
    node_topo_order = [f"{node.op_type.ljust(placeholder_width)} {node_name}\n"
                       for node_name, node in dataloader.nodes_dict.items()]

    # Import onnx lib.
    onnx = import_module("onnx")

    out_folder = os.path.realpath(output_folder)
    if not os.path.exists(out_folder):
        os.makedirs(out_folder, RWX_MODE_FOR_OWNER)

    graph_file = os.path.join(out_folder, "graph.onnx")
    topological_order_file = os.path.join(out_folder, "topological_order.txt")

    # Refuse to overwrite previous outputs.
    for target in (topological_order_file, graph_file):
        if os.path.exists(target):
            err_msg = f"{os.path.basename(target)} already exists."
            log.error(err_msg)
            raise FileExistsError(err_msg)

    # Write topological order to disk.
    with os.fdopen(os.open(topological_order_file, WRITE_FLAGS, stat.S_IRUSR),
                   "w") as topo_file:
        topo_file.writelines(node_topo_order)

    try:
        # Save graph to disk.
        onnx.save_model(dataloader.inferred_model, graph_file)
        os.chmod(graph_file, RW_MODE_FOR_OWNER)
    except Exception:
        # Remove partial outputs on any failure (not only OS errors) so a
        # rerun is not blocked by the "already exists" checks above.
        for target in (topological_order_file, graph_file):
            if os.path.exists(target):
                os.remove(target)
        raise
예제 #22
0
def get_framework_type(model_path):
    """
    Get framework type.

    A file whose first `BINARY_HEADER_PYTORCH_BITS` bytes equal
    `BINARY_HEADER_PYTORCH_FILE` is reported as PyTorch; any other readable
    file is reported as TensorFlow.

    Raises:
        UnknownModel: If the file cannot be opened or read.
    """
    try:
        with open(model_path, 'rb') as model_file:
            header = model_file.read(BINARY_HEADER_PYTORCH_BITS)
    except IOError:
        unsupported = UnknownModel("Get UNSUPPORTED model.")
        log.error(str(unsupported))
        raise unsupported

    if header == BINARY_HEADER_PYTORCH_FILE:
        return FrameworkType.PYTORCH.value
    return FrameworkType.TENSORFLOW.value
예제 #23
0
    def visit_Call(self, node):
        """
        Callback function when visit AST tree.

        Converts a matched API call by replacing the node in its parent and
        recording the replacement in `self._new_call_nodes`. It logs a
        "not converted" warning for standard or found-but-unconverted APIs,
        and then visits the (possibly replaced) node's children. If that visit
        fails, the original and new code are logged before re-raising.
        """
        code = pasta.dump(node)
        api_name = pasta.dump(node.func)

        # The parent node first call is equal to this node, skip when parent node is replaced.
        # This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
        # P.Reshape()(out, (out.size(0), -1)), will skip P.Reshape() in following visiting.
        # Access from the penultimate element in reverse order.
        for parent_node in self._stack[-2::-1]:
            if parent_node in self._new_call_nodes and pasta.dump(
                    parent_node).startswith(api_name):
                return
        parent = self._stack[-2]
        new_node = None
        new_code = code
        matched_api_name, match_case = self.match_api(
            node.func, self._is_forward_function)
        if match_case in [
                ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED
        ]:
            new_node = self._convert_call(node, matched_api_name)
        elif match_case in [
                ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND
        ]:
            self._process_log.warning(node.lineno, node.col_offset,
                                      LOG_FMT_NOT_CONVERT % (api_name, ''))

        if parent and new_node:
            update_line_col = _LineColEditVisitor()
            update_line_col.update(new_node, node)
            pasta.ast_utils.replace_child(parent, node, new_node)
            self._new_call_nodes.append(new_node)
            # Record the replacement so the error log below shows the new code
            # instead of repeating the original.
            new_code = pasta.dump(new_node)

            node = new_node
            self._stack[-1] = node
        try:
            self.generic_visit(node)
        except Exception:
            logger.error('original code:%s, new code:%s',
                         code,
                         new_code,
                         exc_info=True)
            raise
예제 #24
0
    def _f(graph_path: str,
           sample_shape: tuple,
           output_folder: str,
           report_folder: str = None):
        """
        Run the wrapped converter only when PyTorch is importable.

        Raises:
            ModuleNotFoundError: If `torch` cannot be found.
        """
        # Check whether pytorch is installed.
        if not find_spec("torch"):
            error = ModuleNotFoundError(
                "PyTorch is required when using graph based "
                "scripts converter, and PyTorch version must "
                "be consisted with model generation runtime.")
            # No exception is being handled here, so log.exception would have
            # no traceback to attach; log the message only.
            log.error(str(error))
            raise error

        func(graph_path=graph_path,
             sample_shape=sample_shape,
             output_folder=output_folder,
             report_folder=report_folder)
예제 #25
0
 def _f(graph_path: str,
        sample_shape: tuple,
        output_folder: str,
        report_folder: str = None,
        input_nodes: str = None,
        output_nodes: str = None):
     """
     Run the wrapped converter only when tensorflow and tf2onnx are importable.

     Raises:
         ModuleNotFoundError: If either package cannot be found.
     """
     # tf2onnx is only probed when tensorflow itself is present.
     tf_ready = find_spec("tensorflow") and find_spec("tf2onnx")
     if not tf_ready:
         missing = ModuleNotFoundError(
             "Tensorflow and tf2onnx are required when using "
             "graph based scripts converter.")
         log.error(str(missing))
         raise missing

     func(graph_path=graph_path,
          sample_shape=sample_shape,
          output_folder=output_folder,
          report_folder=report_folder,
          input_nodes=input_nodes,
          output_nodes=output_nodes)
예제 #26
0
    def _check_input_shape(input_shape):
        """
        Check input shape.

        Args:
            input_shape (tuple): Input tensor shape.

        """
        if not input_shape:
            err_msg = "`input_shape` can not be None."
            log.error(err_msg)
            raise ValueError(err_msg)

        for item in input_shape:
            if not isinstance(item, int):
                err_msg = "Only support model with one input now, " \
                          "and each shape value in `input_shape` should be int."
                log.error(err_msg)
                raise ValueError(err_msg)
예제 #27
0
def get_framework_type(model_path):
    """
    Get framework type.

    Checks are applied in this order:
    - A file whose first `BINARY_HEADER_PYTORCH_BITS` bytes equal
      `BINARY_HEADER_PYTORCH_FILE` is PyTorch.
    - Otherwise, a file whose lower-cased last dot-separated name component
      equals `TENSORFLOW_MODEL_SUFFIX` is TensorFlow.
    - Anything else is unknown.

    Raises:
        UnknownModelError: If the file cannot be opened or read.
    """
    try:
        with open(model_path, 'rb') as model_file:
            header = model_file.read(BINARY_HEADER_PYTORCH_BITS)
    except IOError:
        unsupported = UnknownModelError("Get UNSUPPORTED model.")
        log.error(str(unsupported))
        raise unsupported

    if header == BINARY_HEADER_PYTORCH_FILE:
        return FrameworkType.PYTORCH.value
    suffix = os.path.basename(model_path).split(".")[-1].lower()
    if suffix == TENSORFLOW_MODEL_SUFFIX:
        return FrameworkType.TENSORFLOW.value
    return FrameworkType.UNKNOWN.value
예제 #28
0
    def normalize_dict_key(d):
        """
        Normalize dictionary key.

        Note:
            The normalization removes the `:<n>` suffix (e.g. `:0`) from each
            node or output name: everything from the first ':' on is
            dropped. If two keys normalize to the same name, the later one
            wins.

        Args:
            d (dict): Dictionary where keys are node/output names.

        Returns:
            dict, normalized dictionary; an `OrderedDict` when `d` is one.

        Raises:
            TypeError: If `d` is not a dictionary.
            ValueError: If the normalized dictionary is empty (i.e. `d` is empty).
        """
        if not isinstance(d, dict):
            error_msg = "Error occurs in normalizing dictionary key. " \
                        "Object passed in is not a dictionary."
            log.error(error_msg)
            raise TypeError(error_msg)

        # OrderedDict is a dict subclass; keep its type for OrderedDict inputs
        # and build the result in a single pass.
        new_d = OrderedDict() if isinstance(d, OrderedDict) else {}
        for key_old, value in d.items():
            new_d[key_old.split(':')[0]] = value

        if not new_d:
            error_msg = "Error occurs in normalizing dictionary key."
            log.error(error_msg)
            raise ValueError(error_msg)
        return new_d
예제 #29
0
 def _f(*args, **kwargs):
     """
     Run `func` as a CLI entry point, reporting failures and exiting non-zero.

     Exceptions whose type is in `cls.raise_from()` are logged as `cls` and
     their detail is printed to the console. A `ModuleNotFoundError` is
     reported as a missing runtime package. Both end the process with
     status -1, so callers and CI do not see a failed conversion as
     successful.
     """
     try:
         res = func(*args, **kwargs)
     except cls.raise_from() as e:
         error = cls() if not msg else cls(msg=msg)
         detail_info = str(e)
         if not isinstance(e, MindConverterException):
             detail_info = cls.normalize_error_msg(str(e))
         log.error(error)
         log_console.error("\n")
         log_console.error(detail_info)
         log_console.error("\n")
         log.exception(e)
         sys.exit(-1)
     except ModuleNotFoundError as e:
         detail_info = "Error detail: Required package not found, please check the runtime environment."
         log_console.error("\n")
         log_console.error(str(e))
         log_console.error(detail_info)
         log_console.error("\n")
         log.exception(e)
         sys.exit(-1)
     return res
예제 #30
0
    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None,
           input_nodes: str = None, output_nodes: str = None):
        """
        Verify the TensorFlow conversion runtime, then run the wrapped converter.

        The process exits with a non-zero status when any of these holds:
        - TensorFlow (per `_check_tf_installation`), tf2onnx, onnx or
          onnxruntime is missing.
        - `lib_version_satisfied` rejects the installed onnx, onnxruntime or
          tf2onnx version.
        """
        def _abort():
            # Same message for missing packages and for unsatisfied versions.
            error = RuntimeIntegrityError(
                f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
                f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
                f"based scripts converter for TensorFlow conversion."
            )
            log.error(error)
            log_console.error("\n")
            log_console.error(str(error))
            log_console.error("\n")
            # Non-zero status so shells/CI do not treat the run as successful.
            sys.exit(-1)

        # Check whether tensorflow and the ONNX toolchain are installed.
        if not _check_tf_installation() or not find_spec("tf2onnx") \
                or not find_spec("onnx") or not find_spec("onnxruntime"):
            _abort()

        onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx")
        ort = import_module("onnxruntime")

        if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
                or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \
                or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER):
            _abort()

        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder,
             input_nodes=input_nodes, output_nodes=output_nodes)