def initialize(self):
        """Initialize the OnnxDataLoader."""

        # Parse ONNX Graph level info
        self._parse_graph()

        # 1. parse all tensors
        self._parse_tensors()

        # 2. parse all nodes; note that tensor parsing must be done first because
        # nodes require tensor info to process node weight sharing.
        self._parse_nodes()

        # 3. parse value info (incl. node output shape)
        if self._is_infer_shape:
            try:
                self._infer_model()
                self._parse_value_info()
                self._parse_node_output_shape()
            except Exception as e:
                log.error(str(e))
                log.exception(e)
                raise e

        # 4. Optimize graph to eliminate some nodes.
        self._find_nodes_to_be_eliminated()

        # 5. build nodes connections
        self.build_nodes_connection()

        # 6. Run onnx model to fetch actual value of eliminated nodes.
        self._fetch_eliminated_nodes_value()
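
The shape-inference step invoked above (`_infer_model`) presumably relies on ONNX's built-in shape inference; below is a minimal standalone sketch of that facility. The toy Relu graph is my own illustration, not taken from the source.

import onnx
from onnx import TensorProto, helper

# Build a tiny graph (Y = Relu(X)) whose output shape is left unspecified.
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 224, 224])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)
relu = helper.make_node("Relu", inputs=["X"], outputs=["Y"])
graph = helper.make_graph([relu], "demo_graph", [x], [y])
model = helper.make_model(graph)

# Shape inference annotates the model with the shape information that steps
# like _parse_value_info() and _parse_node_output_shape() later read back.
inferred = onnx.shape_inference.infer_shapes(model)
print(inferred.graph.output[0])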
Example #2
    def _trace_torch_graph(self, input_shape):
        """
        Trace torch computational graph.

        Args:
            input_shape (tuple): Shape of the input sample used for tracing.

        Returns:
            object, the traced PyTorch graph.
        """
        import torch
        from torch.onnx import OperatorExportTypes
        from .torch_utils import OverloadTorchModuleTemporarily
        from .torch_utils import create_autograd_variable
        from .torch_utils import onnx_tracer

        batched_sample = create_autograd_variable(torch.rand(*input_shape))

        try:
            # Switch the model to evaluation mode.
            self.model.eval()

            with OverloadTorchModuleTemporarily() as _:
                # In newer PyTorch versions, the trace function has a known issue.
                graph = onnx_tracer(self.model, batched_sample,
                                    OperatorExportTypes.ONNX)
            return graph
        except RuntimeError as error:
            log.error(str(error))
            log.exception(error)
            raise error
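
The `onnx_tracer` and `OverloadTorchModuleTemporarily` helpers are internal to this project and are not shown here. For orientation, below is a minimal sketch of the public tracing/export path they build on; the toy model, file name, and shapes are illustrative assumptions.

import torch
import torch.nn as nn

# Toy stand-in for self.model.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
model.eval()  # same eval-mode switch as in the method above

# Stand-in for create_autograd_variable(torch.rand(*input_shape)).
sample = torch.rand(1, 8)

# Public export path that drives ONNX-style tracing under the hood.
torch.onnx.export(model, sample, "toy_model.onnx",
                  input_names=["input"], output_names=["output"],
                  opset_version=11)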
Example #3
    def initialize(self):
        """Initialize the OnnxDataLoader."""

        # Check that initialization conditions are met.
        if not self._check_initialization():
            err = ModuleNotFoundError("Unable to Find ONNX Model")
            log.error(str(err))
            log.exception(err)
            raise err

        # 1. parse all nodes
        self._parse_nodes()

        # 2. parse value info (incl. node output shape)
        if self._is_infer_shape:
            try:
                self._infer_model()
                self._parse_value_info()
                self._parse_node_output_shape()
            except Exception as e:
                log.error(str(e))
                log.exception(e)
                raise e

        # 3. parse all tensors
        self._parse_tensors()

        # 4. Optimize graph to eliminate some nodes.
        self._find_nodes_to_be_eliminated()

        # 5. build nodes connections
        self.build_nodes_connection()

        # 6. Run onnx model to fetch actual value of eliminated nodes.
        self._fetch_eliminated_nodes_value()
Example #4
def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir,
         report, project_path):
    """
    Run converter command.

    Args:
        in_files (str): The file path or directory to convert.
        model_file (str): The PyTorch .pth file to convert based on the graph schema.
        shape (list): The input tensor shape of model_file.
        input_nodes (str): The input node name(s) of the TensorFlow model, separated by ','.
        output_nodes (str): The output node name(s) of the TensorFlow model, separated by ','.
        out_dir (str): The output directory to save the converted files.
        report (str): The report file path.
        project_path (str): PyTorch scripts project path.
    """
    if in_files:
        files_config = {
            'root_path': in_files,
            'in_files': [],
            'outfile_dir': out_dir,
            'report_dir': report if report else out_dir
        }

        if os.path.isfile(in_files):
            files_config['root_path'] = os.path.dirname(in_files)
            files_config['in_files'] = [in_files]
        else:
            for root_dir, _, files in os.walk(in_files):
                for file in files:
                    files_config['in_files'].append(
                        os.path.join(root_dir, file))
        main(files_config)
        log_console.info("\n")
        log_console.info("MindConverter: conversion is completed.")
        log_console.info("\n")

    elif model_file:
        file_config = {
            'model_file': model_file,
            'shape': shape if shape else [],
            'input_nodes': input_nodes,
            'output_nodes': output_nodes,
            'outfile_dir': out_dir,
            'report_dir': report if report else out_dir
        }
        if project_path:
            paths = sys.path
            if project_path not in paths:
                sys.path.append(project_path)

        main_graph_base_converter(file_config)
        log_console.info("\n")
        log_console.info("MindConverter: conversion is completed.")
        log_console.info("\n")
    else:
        error_msg = "`--in_file` and `--model_file` should be set at least one."
        error = FileNotFoundError(error_msg)
        log.error(str(error))
        log.exception(error)
        raise error
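
For reference, the directory-walk branch above reduces to the following standalone helper; this is a restatement for clarity, not code from the source.

import os

def collect_in_files(in_files, out_dir, report=None):
    """Rebuild the files_config dict from a file path or a directory path."""
    config = {
        'root_path': in_files,
        'in_files': [],
        'outfile_dir': out_dir,
        'report_dir': report if report else out_dir,
    }
    if os.path.isfile(in_files):
        config['root_path'] = os.path.dirname(in_files)
        config['in_files'] = [in_files]
    else:
        for root_dir, _, files in os.walk(in_files):
            config['in_files'].extend(os.path.join(root_dir, name) for name in files)
    return config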
Example #5
def check_dependency_integrity(*packages):
    """Check dependency package integrity."""
    try:
        for pkg in packages:
            import_module(pkg)
        return True
    except ImportError as e:
        log.exception(e)
        return False
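
A self-contained usage sketch of the same idea; the package names are only examples and the logger is omitted.

from importlib import import_module

def deps_ok(*packages):
    """Return True only if every named package can be imported."""
    try:
        for pkg in packages:
            import_module(pkg)
        return True
    except ImportError:
        return False

print(deps_ok("os", "json"))       # True: both are standard-library modules
print(deps_ok("no_such_pkg_123"))  # False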
Example #6
 def _f(arch, mapper):
     try:
         output = func(arch, mapper=mapper)
     except cls.raise_from() as e:
         error = cls(msg=msg)
         log.error(msg)
         log.exception(e)
         raise error from e
     return output
Example #7
 def _f(arch, model_path, **kwargs):
     try:
         output = func(arch, model_path=model_path, **kwargs)
     except cls.raise_from() as e:
         error = cls(msg=msg)
         log.error(msg)
         log.exception(e)
         raise error from e
     return output
Example #8
 def _f(file_config):
     try:
         func(file_config=file_config)
     except cls.raise_from() as e:
         error = cls(msg=msg)
         detail_info = f"Error detail: {str(e)}"
         log_console.error(str(error))
         log_console.error(detail_info)
         log.exception(e)
         sys.exit(-1)
Example #9
    def save_source_files(self, out_folder: str, mapper: Mapper,
                          model_name: str,
                          report_folder: str = None,
                          scope_name_map: dict = None) -> NoReturn:
        """
        Save source codes to target folder.

        Args:
            report_folder (str): Report folder.
            mapper (Mapper): Mapper of third party framework and MindSpore.
            model_name (str): Name of the converted model.
            out_folder (str): Output folder.
            scope_name_map (dict): Scope name map of TensorFlow.

        """
        if scope_name_map:
            self._scope_name_map = scope_name_map
        try:
            self._adjust_structure()
            code_fragments = self._generate_codes(mapper)
        except (NodeInputTypeNotSupport, ScriptGenerateFail, ReportGenerateFail) as e:
            log.error("Error occur when generating codes.")
            raise e

        out_folder = os.path.realpath(out_folder)
        if not report_folder:
            report_folder = out_folder
        else:
            report_folder = os.path.realpath(report_folder)

        if not os.path.exists(out_folder):
            os.makedirs(out_folder, self.modes_usr)
        if not os.path.exists(report_folder):
            os.makedirs(report_folder, self.modes_usr)

        for file_name in code_fragments:
            code, report = code_fragments[file_name]
            try:
                with os.fdopen(os.open(os.path.realpath(os.path.join(out_folder, f"{model_name}.py")),
                                       self.flags, self.modes), 'w') as file:
                    file.write(code)
            except IOError as error:
                log.error(str(error))
                log.exception(error)
                raise error

            try:
                with os.fdopen(os.open(os.path.realpath(os.path.join(report_folder,
                                                                     f"report_of_{model_name}.txt")),
                                       self.flags, stat.S_IRUSR), "w") as rpt_f:
                    rpt_f.write(report)
            except IOError as error:
                log.error(str(error))
                log.exception(error)
                raise error
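
The `os.fdopen(os.open(...))` pattern above writes files with explicit permission bits. Below is a minimal sketch of that pattern with assumed flag and mode values; `self.flags` and `self.modes` are not shown in the source, so these constants are illustrative.

import os
import stat

flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC   # assumed value of self.flags
modes = stat.S_IRUSR | stat.S_IWUSR             # assumed value of self.modes: owner read/write only

path = os.path.realpath("generated_model.py")   # illustrative file name
with os.fdopen(os.open(path, flags, modes), "w") as out:
    out.write("# generated code\n")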
Example #10
    def __new__(cls, *args, **kwargs):
        """Control the create action of graph."""
        model_param = args[0] if args else kwargs.get(
            cls._REQUIRED_PARAM_OF_MODEL)
        if not model_param:
            error = ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` "
                               f"can not be None.")
            log.error(str(error))
            log.exception(error)
            raise error

        return super(BaseGraph, cls).__new__(cls)
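
A standalone illustration of this `__new__` guard using a toy class; this is not the real BaseGraph, and the required-parameter name "model" is assumed for the demonstration.

class DemoGraph:
    _REQUIRED_PARAM_OF_MODEL = "model"

    def __new__(cls, *args, **kwargs):
        # Reject construction when the model parameter is missing or falsy.
        model_param = args[0] if args else kwargs.get(cls._REQUIRED_PARAM_OF_MODEL)
        if not model_param:
            raise ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` can not be None.")
        return super().__new__(cls)

    def __init__(self, model):
        self.model = model

DemoGraph(model=object())   # constructed normally
# DemoGraph(model=None)     # would raise ValueError before __init__ runs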
Example #11
 def _f(graph_path, sample_shape, output_folder, report_folder):
     try:
         func(graph_path=graph_path,
              sample_shape=sample_shape,
              output_folder=output_folder,
              report_folder=report_folder)
     except cls.raise_from() as e:
         error = cls(msg=msg)
         detail_info = f"Error detail: {str(e)}"
         log_console.error(str(error))
         log_console.error(detail_info)
         log.exception(e)
         sys.exit(-1)
Example #12
 def _f(*args, **kwargs):
     try:
         output = func(*args, **kwargs)
     except cls.raise_from() as e:
         error = cls(msg=msg)
         error_code = e.error_code() if isinstance(e, MindConverterException) else None
         error.root_exception_error_code = error_code
         log.error(msg)
         log.exception(e)
         raise error
     except Exception as e:
         log.error(msg)
         log.exception(e)
         raise e
     return output
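
These `_f` wrappers are the inner functions of a decorator factory. Below is a generic, self-contained sketch of the same exception-translation pattern; the names translate_errors and lookup are mine, not the project's.

import functools
import logging

log = logging.getLogger(__name__)

def translate_errors(raise_from, new_error, msg):
    """Catch `raise_from`, log it, and re-raise as `new_error(msg)` chained to the cause."""
    def decorator(func):
        @functools.wraps(func)
        def _f(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except raise_from as e:
                log.error(msg)
                log.exception(e)
                raise new_error(msg) from e
        return _f
    return decorator

@translate_errors(raise_from=(KeyError, IndexError), new_error=RuntimeError, msg="lookup failed")
def lookup(mapping, key):
    return mapping[key]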
Example #13
    def _f(graph_path: str,
           sample_shape: tuple,
           output_folder: str,
           report_folder: str = None):
        # Check whether pytorch is installed.
        if not find_spec("torch"):
            error = ModuleNotFoundError(
                "PyTorch is required when using the graph based "
                "scripts converter, and the PyTorch version must "
                "be consistent with the model generation runtime.")
            log.error(str(error))
            log.exception(error)
            raise error

        func(graph_path=graph_path,
             sample_shape=sample_shape,
             output_folder=output_folder,
             report_folder=report_folder)
Example #14
    def normalize_dict_key(d):
        """
        Normalize dictionary key.

        Note:
            The normalization removes the ':0' suffix from each node or output name.

        Args:
            d (dict): Dictionary where keys are node/output names.

        Returns:
            dict, normalized dictionary.
        """
        if not isinstance(d, (dict, OrderedDict)):
            error_msg = "Error occurs in normalizing dictionary key.\
                        Object passed in is not a dictionary."

            error = TypeError(error_msg)
            log.error(error_msg)
            log.exception(error)
            raise error

        new_d = None
        if isinstance(d, dict):
            new_d = {}
            for key_old in d.keys():
                key_new = key_old.split(':')[0]
                new_d[key_new] = d.get(key_old)

        if isinstance(d, OrderedDict):
            new_d = OrderedDict()
            for key_old in d.keys():
                key_new = key_old.split(':')[0]
                new_d[key_new] = d.get(key_old)

        if not new_d:
            error_msg = "Error occurs in normalizing dictionary key."
            error = ValueError(error_msg)
            log.error(error_msg)
            log.exception(error)
            raise error
        return new_d
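
The normalization itself boils down to splitting each key on ':' and keeping the first part. A short standalone demonstration follows; the sample keys are made up.

from collections import OrderedDict

d = OrderedDict([("conv1/kernel:0", 1), ("dense/bias:0", 2)])
normalized = OrderedDict((key.split(":")[0], value) for key, value in d.items())
print(normalized)  # OrderedDict([('conv1/kernel', 1), ('dense/bias', 2)])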
Example #15
 def _f(*args, **kwargs):
     try:
         res = func(*args, **kwargs)
     except cls.raise_from() as e:
         error = cls() if not msg else cls(msg=msg)
         detail_info = str(e)
         if not isinstance(e, MindConverterException):
             detail_info = cls.normalize_error_msg(str(e))
         log.error(error)
         log_console.error("\n")
         log_console.error(detail_info)
         log_console.error("\n")
         log.exception(e)
         sys.exit(0)
     except ModuleNotFoundError as e:
         detail_info = "Error detail: Required package not found, please check the runtime environment."
         log_console.error("\n")
         log_console.error(str(e))
         log_console.error(detail_info)
         log_console.error("\n")
         log.exception(e)
         sys.exit(0)
     return res
Example #16
 def _f(*args, **kwargs):
     try:
         res = func(*args, **kwargs)
     except cls.raise_from() as e:
         error = cls() if not msg else cls(msg=msg)
         detail_info = str(e)
         only_console = False
         if not isinstance(e, MindConverterException):
             detail_info = cls.normalize_error_msg(str(e))
         else:
             only_console = e.only_console
         log_console.error(detail_info, only_console=only_console)
         if not only_console:
             log.error(error)
             log.exception(e)
             log_console.warning(get_lib_notice_info())
         sys.exit(-1)
     except ModuleNotFoundError as e:
         detail_info = "Error detail: Required package not found, please check the runtime environment."
         log_console.error(f"{str(e)}\n{detail_info}")
         log.exception(e)
         log_console.warning(get_lib_notice_info())
         sys.exit(-1)
     return res
Example #17
    def initialize(self):
        """Initialize the OnnxDataLoader."""

        # Check that initialization conditions are met.
        if not self._check_initialization():
            err = ModuleNotFoundError("Unable to Find ONNX Model")
            log.error(str(err))
            log.exception(err)
            raise err

        # 1. parse all nodes
        self._parse_nodes()

        # 2. parse value info (incl. node output shape)
        if self._is_infer_shape:
            try:
                self._infer_model()
            except Exception as e:
                log.error(str(e))
                log.exception(e)
                raise e

        if self.inferred_model:
            try:
                self._parse_value_info()
            except Exception as e:
                log.error(str(e))
                log.exception(e)
                raise e

            try:
                self._parse_node_output_shape()
            except Exception as e:
                log.error(str(e))
                log.exception(e)
                raise e

        # 3. parse all tensors
        self._parse_tensors()

        # 4. build nodes connections
        self.build_nodes_connection()