Example #1
def load_caffe_proto_model(caffe_pb2, proto_path: str, model_path: [str, None] = None):
    # 1. python protobuf is used
    if api_implementation._implementation_type == 'python':
        message = 'Please expect that Model Optimizer conversion might be slow. ' \
                  'You are currently using Python protobuf library implementation. \n'
        try:
            from google.protobuf.pyext import cpp_message
            # Check for Windows OS and the PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION env variable
            if os.name == 'nt' and os.environ.get('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', default='') != 'cpp':
                # 2. cpp implementation is available but not used
                message += 'However, cpp implementation is available, you can boost ' \
                           'model conversion by setting PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION env variable to cpp. \n' \
                           'Run: set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp \n'
        except ImportError:
            # 3. cpp implementation is not available
            message += 'However, you can use the C++ protobuf implementation that is supplied with the OpenVINO toolkit ' \
                       'or build protobuf library from sources. \n' \
                       'Navigate to "install_prerequisites" folder and run: ' \
                       'python -m easy_install protobuf-3.5.1-py($your_python_version)-win-amd64.egg \n' \
                       'set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp'
        print(message + '\n\n' + refer_to_faq_msg(80))

    # Read proto layers
    try:
        proto = caffe_pb2.NetParameter()
        with open(proto_path, "r") as file:
            text_format.Merge(str(file.read()), proto)
    except Exception as e:
        log.error('Exception message: {}\n\n'.format(e) +
                  '    Possible reasons:\n' +
                  '      1. {} does not exist\n'.format(proto_path) +
                  '      2. {} does not have a valid structure, for example, it was downloaded as html\n'.format(
                      proto_path) +
                  '      3. {} contains custom layers or attributes that are not supported\n'.format(proto_path) +
                  '         in Model Optimizer by default.\n\n' +
                  '    After you made sure that {} has a valid structure and still see this issue, then\n'.format(
                      proto_path) +
                  '    you need to generate a python parser for caffe.proto that was used when the model\n' +
                  '    was created.\n' +
                  '    Run "python3 generate_caffe_pb2.py --input_proto ${PATH_TO_CAFFE}/src/caffe/proto/caffe.proto"' +
                  refer_to_faq_msg(1) + '\n\n', extra={'framework_error': True})
        raise FrameworkError('Model Optimizer is not able to parse {}'.format(proto_path)) from e

    # Read model layer if exists
    model = None
    try:
        if model_path:
            model = caffe_pb2.NetParameter()
            with open(model_path, "rb") as infile:
                map = mmap.mmap(infile.fileno(), 0, access=mmap.ACCESS_READ)
                model.MergeFromString(map)
    except Exception as e:
        log.error('Exception message: {}\n\n'.format(e) +
                  '    Possible reasons:\n' +
                  '      1. {} does not exist\n'.format(model_path) +
                  '      2. {} does not have a valid structure\n'.format(model_path), extra={'framework_error': True})
        raise FrameworkError('Model Optimizer is not able to parse {}'.format(model_path)) from e

    return proto, model
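
The snippet below is a minimal usage sketch, not part of the original source: caffe_pb2 is assumed to be the Python module generated from caffe.proto, and both file paths are hypothetical.

import caffe_pb2  # assumed: module generated from caffe.proto

proto, model = load_caffe_proto_model(caffe_pb2,
                                      proto_path='deploy.prototxt',     # hypothetical prototxt path
                                      model_path='weights.caffemodel')  # hypothetical weights path
print('Parsed network "{}"; weights loaded: {}'.format(proto.name, model is not None))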
Example #2
def read_file_to_graph_def(graph_def: [tf_v1.GraphDef, tf_v1.MetaGraphDef], graph_file_name: str = "",
                           is_binary: bool = True):
    """
    Reads file to protobuf
    :param graph_def: GraphDef or MetaGraphDef object to store the network
    :param graph_file_name: path to the file with the graph
    :param is_binary: flag to switch between binary and text protobuf format of the graph file
    :return: GraphDef or MetaGraphDef containing the network with cleared device info.
    """
    try:
        if is_binary:
            with open(graph_file_name, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            with open(graph_file_name, "r") as f:
                text_format.Merge(f.read(), graph_def)
        nodes_to_clear_device = graph_def.node if isinstance(graph_def, tf_v1.GraphDef) else graph_def.graph_def.node
        for node in nodes_to_clear_device:
            node.device = ""
    except Exception as e:
        raise FrameworkError(
            'TensorFlow cannot read the model file: "{}" is an incorrect TensorFlow model file. '
            '\nThe file should contain one of the following TensorFlow graphs:'
            '\n1. frozen graph in text or binary format'
            '\n2. inference graph for freezing with checkpoint (--input_checkpoint) in text or binary format'
            '\n3. meta graph'
            '\n\nMake sure that --input_model_is_text is provided for a model in text format. '
            'By default, a model is interpreted in binary format. Framework error details: {}. ' +
            refer_to_faq_msg(43),
            graph_file_name,
            str(e)
        ) from e
    return graph_def
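
A minimal usage sketch for the frozen-graph case follows; 'frozen_model.pb' is a hypothetical path and tf_v1 is the TensorFlow 1.x compatibility module, imported here explicitly for self-containment.

import tensorflow.compat.v1 as tf_v1

# read a binary frozen graph and clear per-node device assignments
graph_def = read_file_to_graph_def(tf_v1.GraphDef(), 'frozen_model.pb', is_binary=True)
print('{} nodes loaded, device info cleared'.format(len(graph_def.node)))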
Example #3
    def load(self, graph: Graph):
        argv = graph.graph['cmd_params']
        try:
            model_nodes, model_params, model_name, iteration_number = load_symbol_def(
                argv.input_model, argv.input_symbol, argv.input,
                argv.nd_prefix_name, argv.pretrained_model_name,
                argv.legacy_mxnet_model)
        except (ValueError, mxnet.base.MXNetError) as e:
            raise FrameworkError(
                'The following error happened while loading mxnet model {}: {}. '
                + refer_to_faq_msg(53), argv.input_model, str(e)) from e

        if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
            save_params_file(model_name, model_params._arg_params,
                             model_params._aux_params, iteration_number)

        update_extractors_with_extensions(mxnet_op_extractors)
        symbol2nx(graph, model_nodes, model_params, argv.input)
        graph.check_empty_graph(
            'symbol2nx. It may happen due to problems with loaded model')

        graph.graph['layout'] = 'NCHW'
        graph.graph['fw'] = 'mxnet'
        graph.graph[
            'feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3

        extract_node_attrs(graph, mxnet_op_extractor)
Example #4
class TestMainErrors(unittest.TestCase):
    @patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(generate_deprecated_IR_V7=False))
    @patch('mo.main.driver', side_effect=FrameworkError('FW ERROR MESSAGE'))
    def test_FrameworkError(self, mock_driver, mock_argparse):
        with self.assertLogs() as logger:
            main(argparse.ArgumentParser(), 'framework_string')
            self.assertEqual(logger.output, ['ERROR:root:FW ERROR MESSAGE'])
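
The FrameworkError('template {}', arg) pattern used throughout these examples defers formatting of the message; the class below is a hypothetical minimal stand-in that illustrates the idea (the real FrameworkError lives in mo.utils.error and may differ in details).

class FrameworkErrorSketch(Exception):
    # hypothetical stand-in: joins a message template with its positional arguments
    def __init__(self, msg: str, *args):
        super().__init__(msg.format(*args) if args else msg)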
Example #5
def load_onnx_model(file_name: str):
    try:
        onnx_model = onnx.load(file_name)
    except Exception as e:
        raise FrameworkError(
            'Cannot read the model file: "{}" is an incorrect ONNX model file. Details: {}',
            file_name, str(e)) from e

    return onnx_model
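
A minimal usage sketch, assuming 'model.onnx' is a hypothetical path; the chained FrameworkError carries the original ONNX loading error.

try:
    onnx_model = load_onnx_model('model.onnx')
    print('Loaded ONNX graph with {} nodes'.format(len(onnx_model.graph.node)))
except FrameworkError as err:
    print('Conversion failed: {}'.format(err))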
Example #6
def apply_transform(graph: Graph, replacer_cls, **kwargs):
    """
    Safely executes transform if it should be and validates graph after transform execution
    """
    replacer = replacer_cls()
    replacement_id = 'REPLACEMENT_ID'
    if hasattr(replacer, 'replacement_id'):
        replacement_id = replacer.replacement_id

    if hasattr(replacer, 'enabled') and not replacer.enabled:
        log.info("Skip replacer {} (enabled = False)".format(replacer_cls))
        return

    if hasattr(replacer, 'graph_condition') and \
            not all([condition(graph) for condition in replacer.graph_condition]):
        log.info("Skip replacer {} (graph_condition not satisfied)".format(replacer_cls))
        return

    log.debug("Run replacer {}".format(replacer_cls))

    try:
        if hasattr(replacer, 'run_not_recursively') and replacer.run_not_recursively:
            replacer.find_and_replace_pattern(graph)
        else:
            for_graph_and_each_sub_graph_recursively(graph, replacer.find_and_replace_pattern)

        if hasattr(replacer, 'force_clean_up') and replacer.force_clean_up:
            for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

        if hasattr(replacer, 'force_shape_inference') and replacer.force_shape_inference:
            shape_inference(graph)

        if hasattr(replacer, 'run_not_recursively') and replacer.run_not_recursively:
            graph.check_empty_graph(replacer_cls)
            graph.check_shapes_consistency()
        else:
            for_graph_and_each_sub_graph_recursively(graph, lambda _: graph.check_empty_graph(replacer_cls))
            for_graph_and_each_sub_graph_recursively(graph, lambda _: graph.check_shapes_consistency())

    except Error as err:
        raise Error('Exception occurred during running replacer "{}" ({}): {}'.format(
            replacement_id,
            replacer_cls,
            str(err).replace('[REPLACEMENT_ID]', replacement_id),
        )) from err
    except FrameworkError as err:
        raise FrameworkError('{}'.format(str(err))) from err
    except Exception as err:
        raise Exception('Exception occurred during running replacer "{} ({})": {}'.format(
            replacement_id,
            replacer_cls,
            str(err).replace('[REPLACEMENT_ID]', replacement_id),
        )) from err
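
The attributes that apply_transform() probes via hasattr are easiest to see on a concrete replacer; the class below is a hypothetical sketch whose attribute names match the checks above, not an actual Model Optimizer transform.

class DummyReplacer:
    # hypothetical replacer, for illustration only
    replacement_id = 'DummyReplacer'
    enabled = True                                                       # skipped when False
    graph_condition = [lambda graph: graph.graph.get('fw') == 'mxnet']  # all conditions must hold
    run_not_recursively = True                                           # apply to the top-level graph only
    force_clean_up = False
    force_shape_inference = False

    def find_and_replace_pattern(self, graph):
        pass  # the actual transformation would rewrite nodes of `graph` here

# apply_transform(graph, DummyReplacer)  # hypothetical call on an already built Graph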
Example #7
def driver(argv: argparse.Namespace, input_model: str, output_model_name: str,
           output_dir: str):
    meta_info = get_meta_info(argv)

    try:
        model_nodes, model_params, model_name, iteration_number = load_symbol_def(
            input_model, argv.input_symbol, argv.input, argv.nd_prefix_name,
            argv.pretrained_model_name, argv.legacy_mxnet_model)
    except (ValueError, mxnet.base.MXNetError) as e:
        raise FrameworkError(
            'The following error happened while loading mxnet model {}: {}. ' +
            refer_to_faq_msg(53), input_model, str(e)) from e

    if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
        save_params_file(model_name, model_params._arg_params,
                         model_params._aux_params, iteration_number)

    update_extractors_with_extensions(mxnet_op_extractors)
    graph = symbol2nx(model_nodes, model_params, argv.input)
    graph.check_empty_graph(
        'symbol2nx. It may happen due to problems with loaded model')

    graph.__setattr__('name', output_model_name)
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'mxnet'
    graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3

    if graph.graph['cmd_params'].generate_experimental_IR_V10:
        version = 10
    else:
        version = 6
    graph.graph[
        'ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version

    extract_node_attrs(graph, mxnet_op_extractor)

    # --------------------------------- LOAD END ------------------------------------------------------

    class_registration.apply_replacements(graph, [
        class_registration.ClassType.FRONT_REPLACER,
        class_registration.ClassType.MIDDLE_REPLACER,
        class_registration.ClassType.BACK_REPLACER
    ])

    prepare_emit_ir(graph=graph,
                    data_type=argv.data_type,
                    output_dir=output_dir,
                    output_model_name=output_model_name,
                    meta_info=meta_info)
    return 0
Example #8
def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpoint: str = "",
                      model_dir: str = "", saved_model_tags: list = [], meta_graph_file: str = "",
                      user_output_node_names_list: list = []):
    # As a provisional solution, use native TF methods to load a model protobuf
    graph_def = tf_v1.GraphDef()
    if isinstance(graph_file_name, str) and (re.match(r'.*\.(ckpt|meta)$', graph_file_name)):
        print('[ WARNING ] The value for the --input_model command line parameter ends with ".ckpt" or ".meta" '
              'extension.\n'
              'It means that the model is not frozen.\n'
              'To load a non-frozen model to Model Optimizer run:'
              '\n\n1. For "*.ckpt" file:'
              '\n- if inference graph is in binary format'
              '\npython3 mo_tf.py --input_model "path/to/inference_graph.pb" --input_checkpoint "path/to/*.ckpt"'
              '\n- if inference graph is in text format'
              '\npython3 mo_tf.py --input_model "path/to/inference_graph.pbtxt" --input_model_is_text '
              '--input_checkpoint "path/to/*.ckpt"'
              '\n\n2. For "*.meta" file:'
              '\npython3 mo_tf.py --input_meta_graph "path/to/*.meta"')
    variables_values = {}
    try:
        if graph_file_name and not meta_graph_file and not checkpoint:
            # frozen graph
            return read_file_to_graph_def(graph_def, graph_file_name, is_binary), variables_values, 'tf'
        if graph_file_name and not meta_graph_file and checkpoint:
            # inference graph and checkpoint
            graph_def = read_file_to_graph_def(graph_def, graph_file_name, is_binary)
            outputs = get_output_node_names_list(graph_def, user_output_node_names_list)
            if os.path.isfile(checkpoint):
                graph_def = freeze_checkpoint(graph_def=graph_def, checkpoint=checkpoint, output_node_names=outputs)
            elif os.path.isdir(checkpoint):
                graph_def, variables_values = freeze_checkpoints(graph_def=graph_def, checkpoint_dir=checkpoint,
                                                                 output_node_names=outputs)
            # we are sure that the checkpoint is an existing file or directory due to cli_parser configuration
            return graph_def, variables_values, 'tf'
        if not graph_file_name and meta_graph_file:
            meta_graph_file = deducing_metagraph_path(meta_graph_file)
            input_meta_graph_def = read_file_to_graph_def(tf_v1.MetaGraphDef(), meta_graph_file, is_binary)
            # pylint: disable=no-member
            with tf_v1.Session() as sess:
                restorer = tf_v1.train.import_meta_graph(input_meta_graph_def)
                restorer.restore(sess, re.sub(r'\.meta$', '', meta_graph_file))
                outputs = get_output_node_names_list(input_meta_graph_def.graph_def, user_output_node_names_list)
                graph_def = tf_v1.graph_util.convert_variables_to_constants(sess, input_meta_graph_def.graph_def,
                                                                            outputs)
                return graph_def, variables_values, 'tf'
        if model_dir:
            # saved model directory
            try:
                env_setup = get_environment_setup("tf")
                # enable eager execution temporarily while TensorFlow 2 model is being loaded
                tf_v1.enable_eager_execution()
                # code to extract GraphDef for TF 2.0 SavedModel format
                # tf.saved_model.load function throws TypeError for TF 1.x SavedModel format in case TF 1.x installed
                imported = tf.saved_model.load(model_dir, saved_model_tags) # pylint: disable=E1120
                # to get a signature by key throws KeyError for TF 1.x SavedModel format in case TF 2.x installed
                concrete_func = imported.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
                # the aggressive inlining parameter is needed to freeze a table of embeddings for the Keras Embedding operation;
                # a model with an Embedding operation cannot be properly converted to IR without this parameter
                if "tensorflow" in env_setup and env_setup["tensorflow"] >= LooseVersion("2.2.0"):
                    frozen_func = convert_variables_to_constants_v2(concrete_func,
                                                                    lower_control_flow=False,
                                                                    aggressive_inlining=True)  # pylint: disable=E1123
                else:
                    frozen_func = convert_variables_to_constants_v2(concrete_func,
                                                                    lower_control_flow=False)  # pylint: disable=E1123
                graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
                # disable eager execution since next steps are executed with a graph in non-eager mode
                tf_v1.disable_eager_execution()
                return graph_def, variables_values, 'tf2'
            except (TypeError, KeyError):
                # disable eager execution since TensorFlow 1 model is handled
                tf_v1.disable_eager_execution()
                # code to extract GraphDef for TF 1.0 SavedModel format
                tags = saved_model_tags if saved_model_tags is not None else [tf_v1.saved_model.tag_constants.SERVING]
                with tf_v1.Session() as sess:
                    meta_graph_def = tf_v1.saved_model.loader.load(sess, tags, model_dir)
                    outputs = get_output_node_names_list(meta_graph_def.graph_def, user_output_node_names_list)
                    graph_def = tf_v1.graph_util.convert_variables_to_constants(sess, meta_graph_def.graph_def, outputs)
                    return graph_def, variables_values, 'tf'
            except Exception as e:
                raise FrameworkError('SavedModel format load failure: {}', e) from e
    except Exception as e:
        raise FrameworkError('Cannot load input model: {}', e) from e
    raise Error("Unknown configuration of input model parameters")
Example #9
def load_tf_graph_def(graph_file_name: str = "",
                      is_binary: bool = True,
                      checkpoint: str = "",
                      model_dir: str = "",
                      saved_model_tags: list = [],
                      meta_graph_file: str = "",
                      user_output_node_names_list: list = []):
    # As a provisional solution, use native TF methods to load a model protobuf
    graph_def = tf_v1.GraphDef()
    if isinstance(graph_file_name, str) and (re.match(r'.*\.(ckpt|meta)$',
                                                      graph_file_name)):
        print(
            '[ WARNING ] The value for the --input_model command line parameter ends with ".ckpt" or ".meta" '
            'extension.\n'
            'It means that the model is not frozen.\n'
            'To load a non-frozen model to Model Optimizer run:'
            '\n\n1. For "*.ckpt" file:'
            '\n- if inference graph is in binary format'
            '\npython3 mo_tf.py --input_model "path/to/inference_graph.pb" --input_checkpoint "path/to/*.ckpt"'
            '\n- if inference graph is in text format'
            '\npython3 mo_tf.py --input_model "path/to/inference_graph.pbtxt" --input_model_is_text '
            '--input_checkpoint "path/to/*.ckpt"'
            '\n\n2. For "*.meta" file:'
            '\npython3 mo_tf.py --input_meta_graph "path/to/*.meta"')
    variables_values = {}
    try:
        if graph_file_name and not meta_graph_file and not checkpoint:
            # frozen graph
            return read_file_to_graph_def(graph_def, graph_file_name,
                                          is_binary), variables_values
        if graph_file_name and not meta_graph_file and checkpoint:
            # inference graph and checkpoint
            graph_def = read_file_to_graph_def(graph_def, graph_file_name,
                                               is_binary)
            outputs = get_output_node_names_list(graph_def,
                                                 user_output_node_names_list)
            if os.path.isfile(checkpoint):
                graph_def = freeze_checkpoint(graph_def=graph_def,
                                              checkpoint=checkpoint,
                                              output_node_names=outputs)
            elif os.path.isdir(checkpoint):
                graph_def, variables_values = freeze_checkpoints(
                    graph_def=graph_def,
                    checkpoint_dir=checkpoint,
                    output_node_names=outputs)
            # we are sure that the checkpoint is an existing file or directory due to cli_parser configuration
            return graph_def, variables_values
        if not graph_file_name and meta_graph_file:
            meta_graph_file = deducing_metagraph_path(meta_graph_file)
            input_meta_graph_def = read_file_to_graph_def(
                tf_v1.MetaGraphDef(), meta_graph_file, is_binary)
            # pylint: disable=no-member
            with tf_v1.Session() as sess:
                restorer = tf_v1.train.import_meta_graph(input_meta_graph_def)
                restorer.restore(sess, re.sub(r'\.meta$', '', meta_graph_file))
                outputs = get_output_node_names_list(
                    input_meta_graph_def.graph_def,
                    user_output_node_names_list)
                graph_def = tf_v1.graph_util.convert_variables_to_constants(
                    sess, input_meta_graph_def.graph_def, outputs)
                return graph_def, variables_values
        if model_dir:
            # saved model directory
            tags = saved_model_tags if saved_model_tags is not None else [
                tf_v1.saved_model.tag_constants.SERVING
            ]
            with tf_v1.Session() as sess:
                meta_graph_def = tf_v1.saved_model.loader.load(
                    sess, tags, model_dir)
                outputs = get_output_node_names_list(
                    meta_graph_def.graph_def, user_output_node_names_list)
                graph_def = tf_v1.graph_util.convert_variables_to_constants(
                    sess, meta_graph_def.graph_def, outputs)
                return graph_def, variables_values
    except Exception as e:
        raise FrameworkError('Cannot load input model: {}', e) from e
    raise Error("Unknown configuration of input model parameters")
Example #10
def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, outputs: list, output_dir: str,
           scale: float,
           placeholder_shapes: [None, list, np.array] = None,
           mean_scale_values: [dict, list] = ()):
    meta_info = get_meta_info(argv)

    try:
        model_nodes, model_params, model_name, iteration_number = load_symbol_def(input_model, argv.input_symbol,
                                                                                  argv.input,
                                                                                  argv.nd_prefix_name,
                                                                                  argv.pretrained_model_name,
                                                                                  argv.legacy_mxnet_model)
    except (ValueError, mxnet.base.MXNetError) as e:
        raise FrameworkError(
            'The following error happened while loading mxnet model {}: {}. ' +
            refer_to_faq_msg(53),
            input_model,
            str(e)
        ) from e

    if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
        save_params_file(model_name, model_params._arg_params, model_params._aux_params, iteration_number)

    update_extractors_with_extensions(mxnet_op_extractors)
    graph = symbol2nx(model_nodes, model_params, argv.input)
    check_empty_graph(graph, 'symbol2nx. It may happen due to problems with loaded model')

    graph.__setattr__('name', output_model_name)
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'mxnet'
    graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 4
    graph = extract_node_attrs(graph, mxnet_op_extractor)
    check_softmax_node_inputs(graph)

    user_shapes, packed_outputs, _ = user_data_repack(graph, placeholder_shapes, outputs, None)
    output_op_nodes = add_output_ops(graph, packed_outputs)
    input_op_nodes = add_input_ops(graph, user_shapes, True)

    try:
        override_placeholder_shapes(graph, user_shapes, argv.batch)
    except ValueError as err:
        raise Error(
            'The following error happened while processing input shapes: {}. ' +
            refer_to_faq_msg(54),
            str(err)
        ) from err
    check_empty_graph(graph, 'add_output_ops and add_input_ops')

    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    add_input_data_to_prior_boxes(graph, argv.input)

    graph = create_tensor_nodes(graph)

    graph_clean_up(graph)
    remove_output_ops(graph)
    mark_outputs(graph)
    remove_output_ops(graph)

    graph_clean_up(graph)

    log.debug("After removing specific nodes for output")

    print_graph_stat(graph)

    graph = partial_infer(graph)
    graph_clean_up(graph)
    check_empty_graph(graph, 'partial_infer')

    duplicate_shared_weights(graph)

    scale_input(graph, scale)
    add_mean_scale_values(graph, mean_scale_values)

    remove_op_nodes(graph, {'identity': True})

    graph_clean_up(graph)

    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)
    fuse_pad(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    convert_batch_norm(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up(graph)

        # Fusing the sequences of Mul/Add operations
        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        # Fusing linear operation to Convolution
        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    fuse_pad(graph)

    # Converting Mul->Add to ScaleShift node
    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up(graph)

    convert_mul_add_to_power(graph)
    convert_add_to_scaleshift(graph)  # scale = 1
    convert_mul_to_scaleshift(graph)  # biases = 0

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    meta_info=meta_info)
    return 0
Example #11
def driver(argv: argparse.Namespace, input_model: str, output_model_name: str,
           output_dir: str):
    meta_info = get_meta_info(argv)

    try:
        model_nodes, model_params, model_name, iteration_number = load_symbol_def(
            input_model, argv.input_symbol, argv.input, argv.nd_prefix_name,
            argv.pretrained_model_name, argv.legacy_mxnet_model)
    except (ValueError, mxnet.base.MXNetError) as e:
        raise FrameworkError(
            'The following error happened while loading mxnet model {}: {}. ' +
            refer_to_faq_msg(53), input_model, str(e)) from e

    if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
        save_params_file(model_name, model_params._arg_params,
                         model_params._aux_params, iteration_number)

    update_extractors_with_extensions(mxnet_op_extractors)
    graph = symbol2nx(model_nodes, model_params, argv.input)
    graph.check_empty_graph(
        'symbol2nx. It may happen due to problems with loaded model')

    graph.__setattr__('name', output_model_name)
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'mxnet'
    graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
    extract_node_attrs(graph, mxnet_op_extractor)

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(
        graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(
        graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    convert_batch_norm(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up(graph)

        # Fusing the sequences of Mul/Add operations
        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        # Fusing linear operation to Convolution
        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    fuse_pad(graph)

    # Converting Mul->Add to ScaleShift node
    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up(graph)

    convert_mul_add_to_power(graph)
    graph_clean_up(graph)
    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up(graph)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    class_registration.apply_replacements(
        graph, class_registration.ClassType.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    prepare_emit_ir(graph=graph,
                    data_type=argv.data_type,
                    output_dir=output_dir,
                    output_model_name=output_model_name,
                    meta_info=meta_info)
    return 0