예제 #1
0
    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None):
        """Validate the PyTorch/ONNX toolchain, then delegate to ``func``.

        ``func`` (a name bound outside this block) is only reached when
        ``torch``, ``onnx`` and ``onnxruntime`` can all be located and the
        installed ``onnx`` / ``onnxruntime`` versions pass
        ``lib_version_satisfied`` against ``ONNX_MIN_VER`` and
        ``ONNXRUNTIME_MIN_VER``. Otherwise the problem is logged and the
        process exits.

        Args:
            graph_path: Forwarded unchanged to ``func``.
            sample_shape: Forwarded unchanged to ``func``.
            output_folder: Forwarded unchanged to ``func``.
            report_folder: Forwarded unchanged to ``func``; defaults to None.

        Raises:
            SystemExit: When a required package is missing or too old.
        """

        def _abort(problem):
            # Report through both loggers; the console copy is framed by
            # newline-only records.
            log.error(problem)
            log_console.error("\n")
            log_console.error(str(problem))
            log_console.error("\n")
            # NOTE(review): exit status is 0 even though this is a failure
            # path; kept as-is because callers may rely on it -- confirm.
            sys.exit(0)

        # Presence check: stops at the first package that cannot be found.
        required_packages = ("torch", "onnx", "onnxruntime")
        if not all(find_spec(name) for name in required_packages):
            _abort(RuntimeIntegrityError(
                f"PyTorch, onnx(>={ONNX_MIN_VER}) and "
                f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) "
                "are required when using graph based "
                "scripts converter, and PyTorch version must "
                "be consisted with model generation runtime."
            ))

        onnx_module = import_module("onnx")
        runtime_module = import_module("onnxruntime")

        # Version check: onnx first, onnxruntime only if onnx passed.
        versions_ok = (
            lib_version_satisfied(onnx_module.__version__, ONNX_MIN_VER)
            and lib_version_satisfied(runtime_module.__version__, ONNXRUNTIME_MIN_VER)
        )
        if not versions_ok:
            _abort(RuntimeIntegrityError(
                f"onnx(>={ONNX_MIN_VER}) and "
                f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
                "based scripts converter for Pytorch conversion."
            ))

        # The return value of ``func`` is discarded; this wrapper returns None.
        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder)
예제 #2
0
def onnx_lib_version_satisfied():
    """Report whether the installed onnx and onnxoptimizer versions pass.

    ``onnx``, ``onnxruntime`` and ``onnxoptimizer.version`` are imported on
    every call. An ``onnxruntime`` version that fails
    ``lib_version_satisfied`` only triggers a console warning; it does not
    influence the return value.

    Returns:
        bool: False when either the ``onnx`` version (vs ``ONNX_MIN_VER``) or
        the ``onnxoptimizer`` version (vs ``ONNXOPTIMIZER_MIN_VER``) fails
        ``lib_version_satisfied``; True otherwise.
    """
    onnx_module = import_module("onnx")
    runtime_module = import_module("onnxruntime")
    optimizer_version_module = import_module("onnxoptimizer.version")

    # onnxruntime mismatch is advisory only: warn and carry on.
    runtime_ok = lib_version_satisfied(runtime_module.__version__,
                                       ONNXRUNTIME_MIN_VER)
    if not runtime_ok:
        log_console.warning(
            "onnxruntime's version should be greater than %s, however current version is %s.",
            ONNXRUNTIME_MIN_VER, runtime_module.__version__)

    # onnx is checked first; onnxoptimizer is only consulted if onnx passed.
    if not lib_version_satisfied(onnx_module.__version__, ONNX_MIN_VER):
        return False
    if not lib_version_satisfied(optimizer_version_module.version, ONNXOPTIMIZER_MIN_VER):
        return False
    return True
예제 #3
0
    def _f(graph_path: str, input_nodes: dict, output_nodes: List[str],
           output_folder: str, report_folder: str):
        """Validate the TensorFlow/ONNX toolchain, then delegate to ``func``.

        Three checks run in order; the first one that fails prints the same
        integrity error via ``_print_error`` and exits the process:

        1. ``_check_tf_installation()`` and ``onnx_satisfied()``.
        2. ``check_common_dependency_integrity`` for ``tensorflow`` or
           ``tensorflow-gpu`` (at least one must pass).
        3. The ``tf2onnx`` version against ``TF2ONNX_MIN_VER`` and
           ``onnx_lib_version_satisfied()``.

        Args:
            graph_path: Forwarded unchanged to ``func``.
            input_nodes: Forwarded unchanged to ``func``.
            output_nodes: Forwarded unchanged to ``func``.
            output_folder: Forwarded unchanged to ``func``.
            report_folder: Forwarded unchanged to ``func``.

        Raises:
            SystemExit: When any of the dependency checks fails.
        """
        # The error is built up front (even on the success path) so every
        # failure branch reports an identical message.
        required_libs = ['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer']
        libs_info = get_third_part_lib_validation_error_info(required_libs)
        integrity_error = RuntimeIntegrityError(
            f"TensorFlow, {libs_info} "
            "are required when using graph based scripts converter for TensorFlow conversion."
        )

        def _abort():
            _print_error(integrity_error)
            # NOTE(review): exit status is 0 even though this is a failure
            # path; kept as-is because callers may rely on it -- confirm.
            sys.exit(0)

        # Check whether tensorflow is installed and the onnx check passes.
        installation_ok = _check_tf_installation() and onnx_satisfied()
        if not installation_ok:
            _abort()

        # Both package names are always checked (no short-circuit between
        # the two calls); passing either one is enough.
        tf_intact = check_common_dependency_integrity("tensorflow")
        tf_gpu_intact = check_common_dependency_integrity("tensorflow-gpu")
        if not (tf_intact or tf_gpu_intact):
            _abort()

        tf2onnx_module = import_module("tf2onnx")

        versions_ok = (
            lib_version_satisfied(tf2onnx_module.__version__, TF2ONNX_MIN_VER)
            and onnx_lib_version_satisfied()
        )
        if not versions_ok:
            _abort()

        # The return value of ``func`` is discarded; this wrapper returns None.
        func(
            graph_path=graph_path,
            input_nodes=input_nodes,
            output_nodes=output_nodes,
            output_folder=output_folder,
            report_folder=report_folder,
        )
예제 #4
0
def torch_version_satisfied(output_queue):
    """Check the installed torch version and publish the verdict on a queue.

    The first ``major.minor.patch`` triple found in ``torch.__version__`` is
    compared against ``TORCH_MIN_VER`` with ``lib_version_satisfied``. When
    no such triple is present the verdict is False.

    Args:
        output_queue: Object exposing ``put``; it receives the verdict.
            Presumably a multiprocessing/queue object -- verify against
            the caller.

    Returns:
        None. The result is delivered through ``output_queue.put`` only.
    """
    torch_module = import_module('torch')
    # The pattern has no groups, so the match text is the version triple.
    version_match = re.search(r"\d+\.\d+\.\d+", torch_module.__version__)
    if version_match is None:
        verdict = False
    else:
        verdict = lib_version_satisfied(version_match.group(), TORCH_MIN_VER)
    output_queue.put(verdict)
예제 #5
0
    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None,
           input_nodes: str = None, output_nodes: str = None):
        """Validate the TensorFlow/ONNX toolchain, then delegate to ``func``.

        ``func`` (a name bound outside this block) is only reached when
        ``_check_tf_installation()`` passes, ``tf2onnx`` / ``onnx`` /
        ``onnxruntime`` can all be located, and their versions pass
        ``lib_version_satisfied`` against the corresponding ``*_MIN_VER``
        constants. Otherwise the problem is logged and the process exits.

        Args:
            graph_path: Forwarded unchanged to ``func``.
            sample_shape: Forwarded unchanged to ``func``.
            output_folder: Forwarded unchanged to ``func``.
            report_folder: Forwarded unchanged to ``func``; defaults to None.
            input_nodes: Forwarded unchanged to ``func``; defaults to None.
            output_nodes: Forwarded unchanged to ``func``; defaults to None.

        Raises:
            SystemExit: When a required package is missing or too old.
        """

        def _abort():
            # Both failure paths report the same message; the error object
            # is only created once a check has actually failed.
            problem = RuntimeIntegrityError(
                f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
                f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
                "based scripts converter for TensorFlow conversion."
            )
            log.error(problem)
            log_console.error("\n")
            log_console.error(str(problem))
            log_console.error("\n")
            # NOTE(review): exit status is 0 even though this is a failure
            # path; kept as-is because callers may rely on it -- confirm.
            sys.exit(0)

        # Check whether tensorflow is installed, then look for the ONNX
        # packages; evaluation stops at the first failure.
        companion_packages = ("tf2onnx", "onnx", "onnxruntime")
        everything_present = (
            _check_tf_installation()
            and all(find_spec(name) for name in companion_packages)
        )
        if not everything_present:
            _abort()

        onnx_module = import_module("onnx")
        tf2onnx_module = import_module("tf2onnx")
        runtime_module = import_module("onnxruntime")

        # Version checks run in the order onnx, onnxruntime, tf2onnx.
        versions_ok = (
            lib_version_satisfied(onnx_module.__version__, ONNX_MIN_VER)
            and lib_version_satisfied(runtime_module.__version__, ONNXRUNTIME_MIN_VER)
            and lib_version_satisfied(tf2onnx_module.__version__, TF2ONNX_MIN_VER)
        )
        if not versions_ok:
            _abort()

        # The return value of ``func`` is discarded; this wrapper returns None.
        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder,
             input_nodes=input_nodes, output_nodes=output_nodes)