def test_input_shape(self, mock_argparse):
        """Check that conversion applies the input shape [1, 2, 3, 4] to the mock model.

        The requested shape presumably comes from the mocked CLI arguments
        (``mock_argparse``); confirm against the fixture.
        """
        main(argparse.ArgumentParser(), fem, 'ov_mock_mo_frontend')
        model_stat = get_model_statistic()

        # Exactly one 'set_partial_shape' call is expected ...
        assert model_stat.set_partial_shape == 1
        # ... and its most recent argument must be the [1, 2, 3, 4] shape.
        assert model_stat.lastArgPartialShape == PartialShape([1, 2, 3, 4])
    def test_element_type(self, mock_argparse):
        """Check that conversion forwards numpy's int8 element type to the mock model once."""
        main(argparse.ArgumentParser(), fem, 'ov_mock_mo_frontend')
        model_stat = get_model_statistic()

        # A single 'set_element_type' call is expected, made with the
        # element type that corresponds to np.int8.
        assert model_stat.set_element_type == 1
        expected_type = get_element_type(np.int8)
        assert model_stat.lastArgElementType == expected_type
def convert(model, **args):
    """Run Model Optimizer on ``model`` using the 'pytorch' framework.

    Every keyword argument is installed as a parser default after being
    converted with ``str``. ``is_dynamic`` is then re-set to its raw value,
    defaulting to True. The parser is patched to parse an empty argument
    list, so the global ``sys.argv`` is ignored.

    The three-argument ``main(parser, None, 'pytorch')`` call is tried first.
    If it raises, the conversion is retried with the two-argument
    ``main(parser, 'pytorch')`` form and ``_prepare_ir(argv, old_api=True)``.
    That fallback presumably targets an older Model Optimizer API; verify
    against the supported versions.

    Raises:
        Exception: if ``main`` returns a truthy (non-zero) status.
    """
    # NOTE(review): this patch of mo.main.prepare_ir is never undone, so it
    # stays in effect for later users of mo.main in the same process.
    mo.main.prepare_ir = _prepare_ir

    parser = get_common_cli_parser()
    parser.set_defaults(input_model=model,
                        extensions=os.path.join(os.path.dirname(__file__),
                                                'mo_extensions'),
                        ie_is_available=False)
    for arg, value in args.items():
        parser.set_defaults(**{arg: str(value)})
    # Overrides the stringified value set above: is_dynamic keeps its raw type.
    parser.set_defaults(is_dynamic=args.get("is_dynamic", True))

    # Replace original parser to ignore global sys.argv
    origin_parse = parser.parse_args
    parser.parse_args = lambda: origin_parse([])

    try:
        err = main(parser, None, 'pytorch')
    except Exception:
        # Catch Exception rather than using a bare 'except:'. The bare form
        # would also swallow KeyboardInterrupt/SystemExit and quietly retry
        # the conversion instead of aborting. If the fallback also fails, its
        # exception propagates with the first failure attached as context.
        mo.main.prepare_ir = lambda argv: _prepare_ir(argv, old_api=True)
        err = main(parser, 'pytorch')
    if err:
        raise Exception('model conversion failed')
    def test_extract_subgraph(self, mock_argparse):
        """Conversion must call 'extract_subgraph' once and neither override_all_* method."""
        main(argparse.ArgumentParser(), fem, 'ov_mock_mo_frontend')
        model_stat = get_model_statistic()

        # Call counts, in order:
        # (override_all_inputs, override_all_outputs, extract_subgraph)
        observed_calls = (model_stat.override_all_inputs,
                          model_stat.override_all_outputs,
                          model_stat.extract_subgraph)
        assert observed_calls == (0, 0, 1)
    def test_set_batch_size(self, mock_argparse):
        """Check that a batch size is applied to a dynamic leading dimension.

        The mock model reports input shape [-1, 2, 3, 4]. After conversion,
        the last shape passed to 'set_partial_shape' must be [123, 2, 3, 4].
        The value 123 presumably comes from the mocked CLI arguments; confirm
        against the ``mock_argparse`` fixture.
        """
        mock_return_partial_shape(PartialShape([-1, 2, 3, 4]))
        main(argparse.ArgumentParser(), fem, 'ov_mock_mo_frontend')
        stat = get_model_statistic()

        # verify that 'get_partial_shape' and 'set_partial_shape' were called;
        # 2 is because mock model has 2 inputs
        assert stat.get_partial_shape == 2
        assert stat.set_partial_shape == 2
        # The dynamic first dimension was replaced with the batch size.
        assert stat.lastArgPartialShape == PartialShape([123, 2, 3, 4])
    def test_override_same_outputs(self, mock_argparse):
        """Only inputs are overridden when the requested outputs are unchanged.

        Because the outputs stay the same, neither 'override_all_outputs' nor
        'extract_subgraph' should be called.
        """
        main(argparse.ArgumentParser(), fem, 'ov_mock_mo_frontend')
        model_stat = get_model_statistic()

        # Call counts, in order:
        # (override_all_inputs, override_all_outputs, extract_subgraph)
        observed_calls = (model_stat.override_all_inputs,
                          model_stat.override_all_outputs,
                          model_stat.extract_subgraph)
        assert observed_calls == (1, 0, 0)
    def test_error_input_model_no_framework(self, mock_argparse):
        """Without an explicit framework, an unsupported model must not be converted.

        The frontend's 'supported' check is consulted exactly once, and an
        error containing 'can not be deduced' must be logged.
        """
        # NOTE(review): the model path comes from the mocked CLI arguments
        # (per the original comment, 'abc.qwerty'); confirm against the fixture.
        with self.assertLogs() as captured:
            main(argparse.ArgumentParser(), fem, None)

        frontend_stat = get_frontend_statistic()

        assert any('can not be deduced' in record for record in captured.output)

        # The frontend was asked once whether it supports the model.
        assert frontend_stat.supported == 1
    def test_error_batch(self, mock_argparse):
        """Refuse to apply a batch size when the first dimension is not a batch.

        The mock model reports input shape [122, 2, 3, 4], whose fixed first
        dimension doesn't look like a batch. Model Optimizer must therefore
        convert nothing and log an error that mentions 'question=39'.
        """
        # First dimension doesn't look like a batch,
        # so MO shall not convert anything and produce specified error
        mock_return_partial_shape(PartialShape([122, 2, 3, 4]))
        with self.assertLogs() as logger:
            main(argparse.ArgumentParser(), fem, 'ov_mock_mo_frontend')

        stat = get_model_statistic()

        assert any('question=39' in s for s in logger.output)

        # verify that 'get_partial_shape' was called
        assert stat.get_partial_shape == 1
        # verify that 'set_partial_shape' was not called
        assert stat.set_partial_shape == 0
    def test_simple_convert(self, mock_argparse):
        """Convert with an explicitly named mock frontend and check the produced IR.

        The conversion must print the XML and BIN output paths on lines that
        start with '[ SUCCESS ]'. It must call 'convert_model' once without
        probing 'supported', and the written XML must contain the frontend
        name.
        """
        f = io.StringIO()
        with redirect_stdout(f):
            main(argparse.ArgumentParser(), fem, 'ov_mock_mo_frontend')
            out = f.getvalue()

        # Assert on the match objects first. This gives a readable failure,
        # rather than an AttributeError on None, when a SUCCESS line is missing.
        xml_match = re.search(r'\[ SUCCESS \] XML file: (.*)', out)
        bin_match = re.search(r'\[ SUCCESS \] BIN file: (.*)', out)
        assert xml_match is not None, 'XML file path not reported in output:\n' + out
        assert bin_match is not None, 'BIN file path not reported in output:\n' + out
        # '.' does not match '\n' but does match '\r', so strip any CR that
        # CRLF line endings leave behind.
        xml_file = xml_match.group(1).replace("\r", "")
        bin_file = bin_match.group(1).replace("\r", "")
        assert xml_file and bin_file

        # verify that 'convert' was called, and 'supported' was not
        stat = get_frontend_statistic()
        assert stat.convert_model == 1
        assert stat.supported == 0
        # verify that meta info is added to XML file
        with open(xml_file) as file:
            assert 'ov_mock_mo_frontend' in file.read()
Exemplo n.º 10
0
    def test_convert_framework_discover(self, mock_argparse):
        """Convert without naming a framework and check that the frontend is discovered.

        With ``framework=None``, 'supported' must be called once and 'get_name'
        at least once. Conversion must succeed, print the XML and BIN output
        paths, and write an XML that contains 'openvino_mock_mo_frontend'.
        """
        f = io.StringIO()
        with redirect_stdout(f):
            main(argparse.ArgumentParser(), fem, None)
            out = f.getvalue()

        # Assert on the match objects first. This gives a readable failure,
        # rather than an AttributeError on None, when a SUCCESS line is missing.
        xml_match = re.search(r'\[ SUCCESS \] XML file: (.*)', out)
        bin_match = re.search(r'\[ SUCCESS \] BIN file: (.*)', out)
        assert xml_match is not None, 'XML file path not reported in output:\n' + out
        assert bin_match is not None, 'BIN file path not reported in output:\n' + out
        # '.' does not match '\n' but does match '\r', so strip any CR that
        # CRLF line endings leave behind.
        xml_file = xml_match.group(1).replace("\r", "")
        bin_file = bin_match.group(1).replace("\r", "")
        assert xml_file and bin_file

        # verify that 'convert', 'supported' and 'get_name' were called
        stat = get_frontend_statistic()
        assert stat.convert_model == 1
        assert stat.supported == 1
        assert stat.get_name > 0

        # verify that meta info is added to XML file
        with open(xml_file) as file:
            assert 'openvino_mock_mo_frontend' in file.read()
Exemplo n.º 11
0
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import sys

from openvino.tools.mo.utils.cli_parser import get_onnx_cli_parser

if __name__ == "__main__":
    # These imports are only needed when the file is executed as a script.
    from openvino.tools.mo.main import main
    from openvino.frontend import FrontEndManager  # pylint: disable=no-name-in-module,import-error

    # Build the ONNX CLI parser and a frontend manager, run the conversion,
    # and use main()'s return value as the process exit status.
    onnx_parser = get_onnx_cli_parser()
    frontend_manager = FrontEndManager()
    exit_status = main(onnx_parser, frontend_manager, 'onnx')
    sys.exit(exit_status)
Exemplo n.º 12
0
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import sys

from openvino.tools.mo.utils.cli_parser import get_all_cli_parser

from openvino.frontend import FrontEndManager  # pylint: disable=no-name-in-module,import-error

if __name__ == "__main__":
    from openvino.tools.mo.main import main

    # One FrontEndManager is used twice: to build the generic CLI parser and
    # as the manager handed to main().
    fem = FrontEndManager()
    all_parser = get_all_cli_parser(fem)
    exit_status = main(all_parser, fem, 'paddle')
    sys.exit(exit_status)
Exemplo n.º 13
0
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import sys

from openvino.tools.mo.utils.cli_parser import get_tf_cli_parser

if __name__ == "__main__":
    from openvino.tools.mo.main import main

    # No FrontEndManager is passed for TensorFlow (None). main()'s return
    # value becomes the process exit status.
    tf_parser = get_tf_cli_parser()
    exit_status = main(tf_parser, None, 'tf')
    sys.exit(exit_status)
Exemplo n.º 14
0
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import sys

from openvino.tools.mo.utils.cli_parser import get_caffe_cli_parser

if __name__ == "__main__":
    from openvino.tools.mo.main import main

    # No FrontEndManager is passed for Caffe (None). main()'s return value
    # becomes the process exit status.
    caffe_parser = get_caffe_cli_parser()
    exit_status = main(caffe_parser, None, 'caffe')
    sys.exit(exit_status)
Exemplo n.º 15
0
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import sys

from openvino.tools.mo.utils.cli_parser import get_mxnet_cli_parser

if __name__ == "__main__":
    from openvino.tools.mo.main import main

    # No FrontEndManager is passed for MXNet (None). main()'s return value
    # becomes the process exit status.
    mxnet_parser = get_mxnet_cli_parser()
    exit_status = main(mxnet_parser, None, 'mxnet')
    sys.exit(exit_status)
Exemplo n.º 16
0
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import sys

from openvino.tools.mo.utils.cli_parser import get_kaldi_cli_parser

if __name__ == "__main__":
    from openvino.tools.mo.main import main

    # No FrontEndManager is passed for Kaldi (None). main()'s return value
    # becomes the process exit status.
    kaldi_parser = get_kaldi_cli_parser()
    exit_status = main(kaldi_parser, None, 'kaldi')
    sys.exit(exit_status)
Exemplo n.º 17
0
 def test_FrameworkError(self, mock_argparse, mock_driver):
     """main() must log exactly one root ERROR record: 'FW ERROR MESSAGE'.

     The error presumably originates from ``mock_driver``; confirm against
     the fixture.
     """
     expected_records = ['ERROR:root:FW ERROR MESSAGE']
     with self.assertLogs() as captured:
         main(argparse.ArgumentParser(), None, 'framework_string')
         self.assertEqual(captured.output, expected_records)