def _check_data(self, data_1, data_2):
     """Check the shape of two inputs."""
     if len(data_1) != len(data_2):
         logger_console.error(
             f"The construct of {self._source_script_name} and that of {self._target_script_name} are not matched."
         )
         exit(0)
Exemplo n.º 2
0
 def _f(file_config):
     """Run ``func`` on *file_config*, converting expected failures into an exit.

     Exceptions of the types given by ``cls.raise_from()`` are reported:
     - the console gets the wrapper error built from ``msg``, then the
       original exception detail;
     - ``log.exception`` records the traceback;
     - the process exits with status -1.

     ``func``, ``cls`` and ``msg`` come from the enclosing scope, which is
     not shown here.
     """
     try:
         func(file_config=file_config)
     except cls.raise_from() as err:
         wrapped_error = cls(msg=msg)
         detail_line = f"Error detail: {err!s}"
         for message in (str(wrapped_error), detail_line):
             log_console.error(message)
         log.exception(err)
         sys.exit(-1)
Exemplo n.º 3
0
 def _f(graph_path, sample_shape, output_folder, report_folder):
     """Forward the graph-conversion arguments to ``func``, exiting on known errors.

     Exceptions of the types given by ``cls.raise_from()`` are reported:
     - the console gets the wrapper error built from ``msg``, then the
       original detail;
     - ``log.exception`` records the traceback;
     - the process exits with status -1.

     ``func``, ``cls`` and ``msg`` come from the enclosing scope, which is
     not shown here.
     """
     call_kwargs = {
         "graph_path": graph_path,
         "sample_shape": sample_shape,
         "output_folder": output_folder,
         "report_folder": report_folder,
     }
     try:
         func(**call_kwargs)
     except cls.raise_from() as err:
         wrapped_error = cls(msg=msg)
         detail_line = f"Error detail: {err!s}"
         for message in (str(wrapped_error), detail_line):
             log_console.error(message)
         log.exception(err)
         sys.exit(-1)
Exemplo n.º 4
0
def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir,
         report):
    """
    Dispatch a conversion run based on which input was supplied.

    ``in_files`` takes precedence: every file under it (or the single file
    itself) is handed to ``main``. Otherwise ``model_file`` is handed to
    ``main_graph_base_converter``. If neither is set, the error is logged and
    the process exits with status -1.

    Args:
        in_files (str): The file path or directory to convert.
        model_file (str): The model to convert on graph based schema.
        shape (Sequence[tuple]): The input tensor shape of the model.
        input_nodes (Sequence[str]): The input node(s) name of model.
        output_nodes (Sequence[str]): The output node(s) name of model.
        out_dir (str): The output directory to save converted file.
        report (str): The report file path; falls back to ``out_dir``.
    """
    report_dir = report or out_dir

    if in_files:
        if os.path.isfile(in_files):
            root_path = os.path.dirname(in_files)
            collected = [in_files]
        else:
            # Walk the directory tree and gather every file path under it.
            root_path = in_files
            collected = [os.path.join(root_dir, file_name)
                         for root_dir, _, file_names in os.walk(in_files)
                         for file_name in file_names]
        main({
            'root_path': root_path,
            'in_files': collected,
            'outfile_dir': out_dir,
            'report_dir': report_dir
        })
        log_console.info("MindConverter: conversion is completed.")
    elif model_file:
        main_graph_base_converter({
            'model_file': model_file,
            'shape': shape or [],
            'input_nodes': input_nodes,
            'output_nodes': output_nodes,
            'outfile_dir': out_dir,
            'report_dir': report_dir
        })
        log_console.info("MindConverter: conversion is completed.")
    else:
        missing_input = FileNotFoundError(
            "`--in_file` and `--model_file` should be set at least one.")
        log.error(str(missing_input))
        log_console.error(str(missing_input))
        sys.exit(-1)
Exemplo n.º 5
0
 def _f(*args, **kwargs):
     """Call ``func`` and return its result, exiting with status -1 on known failures.

     An exception of a type given by ``cls.raise_from()`` is handled as follows:
     - Its detail is printed on the console.
     - A ``MindConverterException`` keeps its own text and ``only_console``
       flag.
     - Any other exception's text goes through ``cls.normalize_error_msg``
       and is treated as ``only_console=False``.
     - Unless ``only_console`` is set, the wrapper error and traceback are
       also logged and the library notice is printed.

     A ``ModuleNotFoundError`` prints a missing-package hint, logs the
     traceback, prints the library notice and exits.
     ``func``, ``cls`` and ``msg`` come from the enclosing scope.
     """
     try:
         result = func(*args, **kwargs)
     except cls.raise_from() as err:
         wrapped_error = cls(msg=msg) if msg else cls()
         if isinstance(err, MindConverterException):
             detail_info, only_console = str(err), err.only_console
         else:
             detail_info, only_console = cls.normalize_error_msg(str(err)), False
         log_console.error(detail_info, only_console=only_console)
         if only_console:
             sys.exit(-1)
         log.error(wrapped_error)
         log.exception(err)
         log_console.warning(get_lib_notice_info())
         sys.exit(-1)
     except ModuleNotFoundError as err:
         hint = "Error detail: Required package not found, please check the runtime environment."
         log_console.error(f"{str(err)}\n{hint}")
         log.exception(err)
         log_console.warning(get_lib_notice_info())
         sys.exit(-1)
     return result
Exemplo n.º 6
0
 def _f(*args, **kwargs):
     """Call ``func`` and return its result, exiting with a failure status on known errors.

     An exception of a type given by ``cls.raise_from()`` is handled as follows:
     - A non-``MindConverterException`` has its text normalised through
       ``cls.normalize_error_msg``.
     - The wrapper error built from ``msg`` is logged.
     - The detail is printed on the console between blank lines.
     - ``log.exception`` records the traceback.

     A ``ModuleNotFoundError`` prints a missing-package hint and logs the
     traceback. Both paths exit with status -1.
     ``func``, ``cls`` and ``msg`` come from the enclosing scope.

     Raises:
         SystemExit: With status -1 on a handled failure.
     """
     try:
         res = func(*args, **kwargs)
     except cls.raise_from() as e:
         error = cls() if not msg else cls(msg=msg)
         detail_info = str(e)
         if not isinstance(e, MindConverterException):
             detail_info = cls.normalize_error_msg(str(e))
         log.error(error)
         log_console.error("\n")
         log_console.error(detail_info)
         log_console.error("\n")
         log.exception(e)
         # A failure must not report success to the caller: exit non-zero,
         # consistent with the other error handlers (previously exit(0)).
         sys.exit(-1)
     except ModuleNotFoundError as e:
         detail_info = "Error detail: Required package not found, please check the runtime environment."
         log_console.error("\n")
         log_console.error(str(e))
         log_console.error(detail_info)
         log_console.error("\n")
         log.exception(e)
         sys.exit(-1)
     return res
Exemplo n.º 7
0
def _print_error(err):
    """Record *err* with the file logger, then echo its text on the console.

    Args:
        err: The error object; it is passed as-is to ``log.error`` and
            converted with ``str`` for ``log_console.error``.
    """
    # File log first, console second.
    log.error(err)
    console_text = str(err)
    log_console.error(console_text)
Exemplo n.º 8
0
    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None):
        """Verify the PyTorch/ONNX runtime dependencies, then call ``func``.

        Two checks must pass:
        - ``torch``, ``onnx`` and ``onnxruntime`` are findable via
          ``find_spec``;
        - the installed ``onnx`` / ``onnxruntime`` versions satisfy
          ``ONNX_MIN_VER`` / ``ONNXRUNTIME_MIN_VER`` according to
          ``lib_version_satisfied``.

        On failure, a ``RuntimeIntegrityError`` is logged and the process
        exits with status -1. Otherwise all arguments are forwarded to
        ``func``.

        Args:
            graph_path: Forwarded to ``func``.
            sample_shape: Forwarded to ``func``.
            output_folder: Forwarded to ``func``.
            report_folder: Forwarded to ``func``; defaults to None.

        Raises:
            SystemExit: With status -1 when a dependency check fails.
        """

        def _abort(error):
            # Report the error to the log file and the console, then stop.
            log.error(error)
            log_console.error("\n")
            log_console.error(str(error))
            log_console.error("\n")
            # Missing or outdated dependencies are a failure: exit non-zero
            # (the previous sys.exit(0) reported success to callers).
            sys.exit(-1)

        # Check whether PyTorch, onnx and onnxruntime are installed.
        if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"):
            _abort(RuntimeIntegrityError(f"PyTorch, onnx(>={ONNX_MIN_VER}) and "
                                         f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) "
                                         f"are required when using graph based "
                                         f"scripts converter, and PyTorch version must "
                                         f"be consistent with model generation runtime."))

        onnx = import_module("onnx")
        ort = import_module("onnxruntime")

        if not lib_version_satisfied(onnx.__version__, ONNX_MIN_VER) \
                or not lib_version_satisfied(ort.__version__, ONNXRUNTIME_MIN_VER):
            _abort(RuntimeIntegrityError(
                f"onnx(>={ONNX_MIN_VER}) and "
                f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
                f"based scripts converter for Pytorch conversion."
            ))

        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder)
Exemplo n.º 9
0
    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None,
           input_nodes: str = None, output_nodes: str = None):
        """Verify the TensorFlow/ONNX runtime dependencies, then call ``func``.

        Two checks must pass:
        - ``_check_tf_installation()`` succeeds and ``tf2onnx``, ``onnx`` and
          ``onnxruntime`` are findable via ``find_spec``;
        - their versions satisfy ``TF2ONNX_MIN_VER`` / ``ONNX_MIN_VER`` /
          ``ONNXRUNTIME_MIN_VER`` according to ``lib_version_satisfied``.

        On failure, a ``RuntimeIntegrityError`` is logged and the process
        exits with status -1. Otherwise all arguments are forwarded to
        ``func``.

        Args:
            graph_path: Forwarded to ``func``.
            sample_shape: Forwarded to ``func``.
            output_folder: Forwarded to ``func``.
            report_folder: Forwarded to ``func``; defaults to None.
            input_nodes: Forwarded to ``func``; defaults to None.
            output_nodes: Forwarded to ``func``; defaults to None.

        Raises:
            SystemExit: With status -1 when a dependency check fails.
        """
        # Both checks report the same requirement message.
        requirement_msg = (
            f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
            f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
            f"based scripts converter for TensorFlow conversion."
        )

        def _abort():
            # Report the requirement error to the log file and console, then stop.
            error = RuntimeIntegrityError(requirement_msg)
            log.error(error)
            log_console.error("\n")
            log_console.error(str(error))
            log_console.error("\n")
            # Missing or outdated dependencies are a failure: exit non-zero
            # (the previous sys.exit(0) reported success to callers).
            sys.exit(-1)

        # Check whether tensorflow and the ONNX toolchain are installed.
        if not _check_tf_installation() or not find_spec("tf2onnx") \
                or not find_spec("onnx") or not find_spec("onnxruntime"):
            _abort()

        onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx")
        ort = import_module("onnxruntime")

        if not lib_version_satisfied(onnx.__version__, ONNX_MIN_VER) \
                or not lib_version_satisfied(ort.__version__, ONNXRUNTIME_MIN_VER) \
                or not lib_version_satisfied(tf2onnx.__version__, TF2ONNX_MIN_VER):
            _abort()

        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder,
             input_nodes=input_nodes, output_nodes=output_nodes)
Exemplo n.º 10
0
            Default output file is saved in the current working directory, with the same name as `fixed_py_file`.
        """)

if __name__ == '__main__':
    # With no command-line arguments, parse ['-h'] so argparse prints the help text.
    args = parser.parse_args(sys.argv[1:] or ['-h'])

    source_py_file = args.source_py_file
    source_ckpt_file = args.source_ckpt_file
    fixed_py_file = args.fixed_py_file

    # Strip only a trailing ".py" from the fixed script's base name;
    # str.replace would also remove ".py" occurring elsewhere in the name.
    fixed_py_name = os.path.basename(fixed_py_file)
    if fixed_py_name.endswith(".py"):
        fixed_py_name = fixed_py_name[:-len(".py")]
    fixed_ckpt_file = extract_fixed_ckpt_file(args.fixed_ckpt_file, fixed_py_name)

    if not source_checker(source_py_file, source_ckpt_file):
        # The check failed, i.e. the files are inconsistent; the old message
        # ("is not inconsistent") stated the opposite.
        logger_console.error(
            "source checkpoint file is inconsistent with source model script."
        )
        sys.exit(-1)

    fix_checkpoint_generator = FixCheckPointGenerator(source_py_file,
                                                      fixed_py_file)
    fix_checkpoint_generator.fix_ckpt(source_ckpt_file, fixed_ckpt_file)
Exemplo n.º 11
0
def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir,
         report, project_path):
    """
    Dispatch a conversion run based on which input was supplied.

    ``in_files`` takes precedence: every file under it (or the single file
    itself) is handed to ``main``. Otherwise ``model_file`` is handed to
    ``main_graph_base_converter``; before that, ``project_path`` (if given)
    is appended to ``sys.path`` unless already present. If neither input is
    set, the error is logged and the process exits with status -1.

    Args:
        in_files (str): The file path or directory to convert.
        model_file(str): The pytorch .pth to convert on graph based schema.
        shape(list): The input tensor shape of module_file.
        input_nodes(str): The input node(s) name of Tensorflow model, split by ','.
        output_nodes(str): The output node(s) name of Tensorflow model, split by ','.
        out_dir (str): The output directory to save converted file.
        report (str): The report file path; falls back to ``out_dir``.
        project_path(str): Pytorch scripts project path.
    """
    report_dir = report or out_dir

    def _announce_completion():
        # Completion banner, surrounded by blank console lines.
        for line in ("\n", "MindConverter: conversion is completed.", "\n"):
            log_console.info(line)

    if in_files:
        if os.path.isfile(in_files):
            root_path = os.path.dirname(in_files)
            collected = [in_files]
        else:
            # Walk the directory tree and gather every file path under it.
            root_path = in_files
            collected = [os.path.join(root_dir, file_name)
                         for root_dir, _, file_names in os.walk(in_files)
                         for file_name in file_names]
        main({
            'root_path': root_path,
            'in_files': collected,
            'outfile_dir': out_dir,
            'report_dir': report_dir
        })
        _announce_completion()
    elif model_file:
        graph_config = {
            'model_file': model_file,
            'shape': shape or [],
            'input_nodes': input_nodes,
            'output_nodes': output_nodes,
            'outfile_dir': out_dir,
            'report_dir': report_dir
        }
        # Make the user's project importable during conversion.
        if project_path and project_path not in sys.path:
            sys.path.append(project_path)
        main_graph_base_converter(graph_config)
        _announce_completion()
    else:
        missing_input = FileNotFoundError(
            "`--in_file` and `--model_file` should be set at least one.")
        log.error(str(missing_input))
        log_console.error("\n")
        log_console.error("mindconverter: error: %s", str(missing_input))
        log_console.error("\n")
        sys.exit(-1)