# Example #1
# 0
# File: mo_tf.py  Project: pc2/CustoNN2
#!/usr/bin/env python3
"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import sys

from mo.utils.versions_checker import check_python_version

if __name__ == "__main__":
    # Stop immediately with the version checker's non-zero status if the
    # running interpreter is rejected.
    version_status = check_python_version()
    if version_status:
        sys.exit(version_status)

    # Imports from the mo package are intentionally deferred until after the
    # version check (presumably so an unsupported interpreter gets a clean
    # exit code instead of an error raised while importing mo).
    from mo.main import main
    from mo.utils.cli_parser import get_tf_cli_parser

    tf_parser = get_tf_cli_parser()
    sys.exit(main(tf_parser, 'tf'))
def _resolve_openvino_dir(requested_dir):
    """Return the OpenVINO install root used to locate the Model Optimizer.

    Preference order: the explicitly requested directory, then the
    ``INTEL_OPENVINO_DIR`` environment variable, then the fallback
    ``/opt/intel/openvino`` (after printing a warning).
    """
    import os

    if requested_dir is not None:
        return requested_dir
    env_dir = os.getenv("INTEL_OPENVINO_DIR")
    if env_dir is not None:
        return env_dir
    print(
        "Could not find an OpenVINO installation. Assuming location in "
        "/opt/intel/openvino, but check that OpenVINO is installed")
    return "/opt/intel/openvino"


def _find_pipeline_configs(input_path):
    """Return every ``*.config`` file next to ``input_path`` (its directory)."""
    import glob
    import os

    # A bare file name means "current directory"; escape the directory so
    # glob metacharacters in it ([, *, ?) are matched literally.
    model_dir = os.path.dirname(input_path) or os.curdir
    return glob.glob(os.path.join(glob.escape(model_dir), '*.config'))


def convert_to_openvino(args, input_dims, graph_chars):
    """Convert a TensorFlow model with the OpenVINO Model Optimizer (``mo``).

    Builds a Model Optimizer command line, temporarily installs it as
    ``sys.argv`` and runs ``mo.main.main`` with the TensorFlow CLI parser.

    Args:
        args: namespace providing ``transformations_config``, ``openvino_dir``,
            ``input``, ``pipeline_config``, ``channel_order`` and ``output_dir``.
            ``args.pipeline_config`` is filled in when it was ``None`` and
            exactly one ``*.config`` file sits next to ``args.input``.
        input_dims: value passed (via ``str()``) as ``--input_shape``.
        graph_chars: object whose ``output_nodes`` items each expose ``.name``;
            the names are passed comma-separated as ``--output``.

    Returns:
        ``None`` if ``args.transformations_config`` is missing; otherwise the
        return value of the Model Optimizer's ``main``.

    Raises:
        SystemExit: with code 1 when no unique pipeline config can be found.
    """
    import os

    if args.transformations_config is None:
        print(
            "Error: --transformations_config args are required for openvino conversion"
        )
        return None

    openvino_dir = _resolve_openvino_dir(args.openvino_dir)

    # Locate the object-detection pipeline file before touching any
    # interpreter state, so a failure here leaves sys.path/sys.argv intact.
    if args.pipeline_config is None:
        candidates = _find_pipeline_configs(args.input)
        if len(candidates) != 1:
            print("Error: No clear pipeline file")
            sys.exit(1)
        args.pipeline_config = candidates[0]

    mo_argv = [
        '',  # placeholder for argv[0]
        "--input_model", args.input,
        "--transformations_config", args.transformations_config,
        "--tensorflow_object_detection_api_pipeline_config", args.pipeline_config,
        "--input_shape", str(input_dims),
    ]
    if args.channel_order == "RGB":
        mo_argv.append("--reverse_input_channels")
    # NOTE(review): keeps everything before the last '/' of output_dir (or the
    # whole value if it has no '/'); presumably output_dir is an output file
    # path — confirm with callers.
    mo_argv += ["--output_dir", args.output_dir.rsplit('/', 1)[0] + '/']
    mo_argv += ["--output", ','.join(node.name for node in graph_chars.output_nodes)]

    # Make the Model Optimizer package importable (only once per process).
    mo_path = os.path.join(openvino_dir, "deployment_tools", "model_optimizer")
    if mo_path not in sys.path:
        sys.path.insert(1, mo_path)
    from mo.main import main
    from mo.utils.cli_parser import get_tf_cli_parser

    # mo parses sys.argv, so swap it in for the call and always restore it.
    saved_argv = sys.argv
    sys.argv = mo_argv
    try:
        return main(get_tf_cli_parser(), 'tf')
    finally:
        sys.argv = saved_argv