Пример #1
0
def apply_offline_transformations(input_model: str, framework: str, transforms: list, compress_fp16: bool = False):
    """Transform a temporary IR model and serialize it together with a mapping file.

    Reads ``input_model + "_tmp.xml"`` through the ngraph "ir" frontend and wraps
    the resulting function in an ``IENetwork``. It then applies
    ``apply_user_transformations`` (with ``transforms``) and
    ``apply_moc_transformations``, plus ``compress_model`` when requested.
    Finally it writes ``input_model + ".xml"`` / ``".bin"`` via ``Serialize``
    and ``input_model + ".mapping"`` via ``GenerateMappingFile``.

    :param input_model: path prefix of the model files, without extension
    :param framework: source framework name; sets the ``extract_names`` flag
        passed to ``GenerateMappingFile`` (True only for 'tf', 'mxnet', 'kaldi')
    :param transforms: user transformations, forwarded unchanged to
        ``apply_user_transformations``
    :param compress_fp16: if truthy, ``compress_model`` is applied to the network
        before serialization (presumably FP16 compression — confirm against
        ``compress_model``)
    """
    # This variable is only needed by GenerateMappingFile transformation
    # to produce correct mapping
    extract_names = framework in ['tf', 'mxnet', 'kaldi']

    # Imports are function-local so this module can be imported even when the
    # OpenVINO / ngraph Python packages are not available.
    from openvino.offline_transformations import GenerateMappingFile, Serialize  # pylint: disable=import-error,no-name-in-module
    from openvino.inference_engine import IENetwork  # pylint: disable=import-error,no-name-in-module
    from ngraph.frontend import FrontEndManager, FrontEnd  # pylint: disable=no-name-in-module,import-error
    from ngraph.impl import Function  # pylint: disable=no-name-in-module,import-error

    fem = FrontEndManager()

    # We have to separate fe object lifetime from fem to
    # avoid segfault during object destruction. So fe must
    # be destructed before fem object explicitly.
    # Keeping `fe` local to this nested helper means its last reference is
    # dropped when the helper returns, while `fem` (captured from the
    # enclosing scope) stays alive until apply_offline_transformations exits.
    # NOTE(review): relies on CPython reference counting; do not hoist `fe`
    # into the outer scope.
    def read_network(path_to_xml):
        fe = fem.load_by_framework(framework="ir")
        f = fe.convert(fe.load(path_to_xml))
        # Hand the ngraph Function to the Inference Engine API via a capsule.
        return IENetwork(Function.to_capsule(f))

    net = read_network(input_model + "_tmp.xml")

    # User-requested transformations run first, then the MOC pipeline.
    apply_user_transformations(net, transforms)
    apply_moc_transformations(net)

    if compress_fp16:
        compress_model(net)

    # Serialize and GenerateMappingFile take paths as UTF-8 encoded bytes.
    Serialize(net, str(input_model + ".xml").encode('utf-8'), (input_model + ".bin").encode('utf-8'))
    path_to_mapping = input_model + ".mapping"
    GenerateMappingFile(net, path_to_mapping.encode('utf-8'), extract_names)
Пример #2
0
def apply_offline_transformations(input_model: str, framework: str,
                                  transforms: list):
    """Run the MOC transformations on a temporary IR and re-serialize it.

    The network is loaded from ``<input_model>_tmp.xml`` / ``_tmp.bin`` and
    handed to ``apply_moc_transformations`` together with ``transforms``. The
    result is written to ``<input_model>.xml`` / ``.bin``, and a name mapping
    is written to ``<input_model>.mapping``.

    :param input_model: path prefix of the model files, without extension
    :param framework: source framework name; only 'tf', 'mxnet' and 'kaldi'
        turn on name extraction in GenerateMappingFile
    :param transforms: forwarded unchanged to ``apply_moc_transformations``
    """
    # GenerateMappingFile is the only consumer of this flag; it needs it to
    # produce a correct mapping for these frameworks.
    names_from_framework = framework in ('tf', 'mxnet', 'kaldi')

    from openvino.inference_engine import read_network  # pylint: disable=import-error,no-name-in-module
    from openvino.offline_transformations import GenerateMappingFile  # pylint: disable=import-error,no-name-in-module

    tmp_xml = input_model + "_tmp.xml"
    tmp_bin = input_model + "_tmp.bin"
    network = read_network(tmp_xml, tmp_bin)

    apply_moc_transformations(network, transforms)

    out_xml, out_bin = input_model + ".xml", input_model + ".bin"
    network.serialize(out_xml, out_bin)

    mapping_file = input_model + ".mapping"
    GenerateMappingFile(network, mapping_file.encode('utf-8'), names_from_framework)
Пример #3
0
def apply_offline_transformations(input_model: str, framework: str, transforms: list):
    """Apply the requested transformations to a temporary IR and serialize it.

    Reads ``<input_model>_tmp.xml`` / ``_tmp.bin``, applies every
    transformation listed in ``transforms`` in order, then writes
    ``<input_model>.xml`` / ``.bin`` and a ``<input_model>.mapping`` file.

    :param input_model: path prefix of the model files, without extension
    :param framework: source framework name; only 'tf', 'mxnet' and 'kaldi'
        turn on name extraction in GenerateMappingFile
    :param transforms: iterable of ``(name, kwargs)`` pairs; each entry is
        applied as ``get_available_transformations()[name](net, **kwargs)``
    :raises Error: if any requested transformation name is not available.
        All names are checked before the network is read, so no
        transformation is applied in that case.
    """
    # This variable is only needed by GenerateMappingFile transformation
    # to produce correct mapping
    extract_names = framework in ['tf', 'mxnet', 'kaldi']

    from openvino.inference_engine import read_network  # pylint: disable=import-error,no-name-in-module
    from openvino.offline_transformations import GenerateMappingFile  # pylint: disable=import-error,no-name-in-module

    # Materialise once: the transforms are iterated twice below (validate, then
    # apply), which would silently skip everything if a generator were passed.
    transforms = list(transforms)

    available_transformations = get_available_transformations()

    # Validate all requested names up front so we fail fast, before the
    # network is read and before any transformation has mutated it.
    for name, _ in transforms:
        if name not in available_transformations:
            raise Error("Transformation {} is not available.".format(name))

    net = read_network(input_model + "_tmp.xml", input_model + "_tmp.bin")

    for name, args in transforms:
        available_transformations[name](net, **args)

    net.serialize(input_model + ".xml", input_model + ".bin")
    path_to_mapping = input_model + ".mapping"
    # GenerateMappingFile takes the output path as UTF-8 encoded bytes.
    GenerateMappingFile(net, path_to_mapping.encode('utf-8'), extract_names)
Пример #4
0
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse

if __name__ == "__main__":
    # Script entry point: re-serialize <input_model>_tmp.{xml,bin} as
    # <input_model>.{xml,bin} and generate <input_model>.mapping.
    # Exits with status 1 if the OpenVINO Python API cannot be imported.
    import sys

    parser = argparse.ArgumentParser()
    # Without --input_model the path arithmetic below would fail with a
    # TypeError on None; let argparse report a proper usage error instead.
    parser.add_argument("--input_model", required=True)
    parser.add_argument("--framework")
    args = parser.parse_args()
    path_to_model = args.input_model

    # This variable is only needed by GenerateMappingFile transformation
    # to produce correct mapping
    extract_names = args.framework in ['tf', 'mxnet', 'kaldi']

    try:
        # Some of these names are imported only to verify that the installed
        # OpenVINO Python API exposes them.
        from openvino.inference_engine import IECore, read_network  # pylint: disable=import-error
        from openvino.offline_transformations import ApplyMOCTransformations, GenerateMappingFile, CheckAPI  # pylint: disable=import-error
    except Exception as e:  # pylint: disable=broad-except
        # Any exception raised while importing the API is reported as a
        # warning and the script terminates with exit status 1.
        print("[ WARNING ] {}".format(e))
        # sys.exit rather than the site-provided `exit`, which is absent under
        # `python -S` and in frozen executables.
        sys.exit(1)

    CheckAPI()

    net = read_network(path_to_model + "_tmp.xml", path_to_model + "_tmp.bin")
    net.serialize(path_to_model + ".xml", path_to_model + ".bin")
    path_to_mapping = path_to_model + ".mapping"
    # GenerateMappingFile takes the output path as UTF-8 encoded bytes.
    GenerateMappingFile(net, path_to_mapping.encode('utf-8'), extract_names)