Exemplo n.º 1
0
def driver_R5(onnx_modelproto_bytes,
              precision: str,
              output_model_name: str,
              outputs: list,
              output_dir: str,
              scale: float,
              user_shapes: [None, list, np.array] = None,
              mean_scale_values: [dict, list] = ()):
    """Convert a serialized ONNX model to IR with the ``ir_version`` 4 pass sequence.

    The model is parsed from bytes, converted to a graph with ``protobuf2nx``
    and then run through the front/middle/back replacers and the fusing
    passes listed below before ``prepare_emit_ir`` is called.

    Args:
        onnx_modelproto_bytes: Serialized ONNX model; wrapped in ``bytes()``
            and parsed with ``onnx.load_from_string``.
        precision: Forwarded to ``prepare_emit_ir`` as ``data_type``.
        output_model_name: Used as the graph name (the name stored in the
            ONNX graph is used when this is empty) and forwarded to
            ``prepare_emit_ir``.
        outputs: Requested outputs, handed to ``user_data_repack``.
        output_dir: Forwarded to ``prepare_emit_ir``.
        scale: Passed to ``scale_input``.
        user_shapes: User-supplied shapes, handed to ``user_data_repack``.
        mean_scale_values: Passed to ``add_mean_scale_values``.

    Returns:
        The ``(weights, xml_string)`` pair returned by ``prepare_emit_ir``.

    Raises:
        Error: If the bytes cannot be parsed as an ONNX model, or if the
            graph cannot be built and pre-processed.
    """
    try:
        model_proto = onnx.load_from_string(bytes(onnx_modelproto_bytes))
    except Exception as e:
        print("[python] onnx exception: ", str(e))
        # Without a parsed model nothing below can work; fail here with the
        # real cause instead of an UnboundLocalError on ``model_proto``.
        raise Error(
            'Cannot load ONNX model from the provided serialized bytes. '
            'Details: {}.', str(e)) from e

    model_graph = model_proto.graph  # pylint: disable=no-member
    log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
    log.debug(
        "Number of all input ports (not true inputs) in graph_def: {}".format(
            len(model_graph.input)))
    log.debug("Number of initializers in graph_def: {}".format(
        len(model_graph.initializer)))
    log.debug("Number of real inputs in graph_def: {}".format(
        len(model_graph.input) - len(model_graph.initializer)))
    update_extractors_with_extensions(onnx_op_extractors)

    try:
        graph = protobuf2nx(model_proto)
        log.debug("Number of nodes in NX graph: {}".format(
            graph.number_of_nodes()))
        graph.__setattr__(
            'name',
            output_model_name if output_model_name else model_proto.graph.name)  # pylint: disable=no-member
        graph.graph['layout'] = 'NCHW'
        graph.graph['fw'] = 'onnx'
        graph.graph[
            'feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
        graph.graph['ir_version'] = 4
        # Basic attribute extraction; the full extractor runs further below.
        extract_node_attrs(graph, lambda node:
                           (True, common_onnx_fields(node)))
    except Exception as e:
        # The model comes from memory, so there is no file name to report;
        # a fixed placeholder fills the first slot of the message.
        raise Error(
            'Cannot pre-process ONNX graph after reading from model file "{}". '
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44), '<in-memory ONNX model>', str(e)) from e
    check_empty_graph(
        graph, 'protobuf2nx. It may happen due to problems with loaded model')
    packed_user_shapes, packed_outputs, _ = user_data_repack(
        graph, user_shapes, outputs, None)

    # Return values of these helpers are not needed here.
    add_output_ops(graph, packed_outputs)
    add_input_ops(graph, packed_user_shapes, True)

    graph_clean_up(graph)
    check_empty_graph(graph, 'add_output_ops and add_input_ops')
    extract_node_attrs(
        graph, lambda node: onnx_op_extractor(
            node, check_for_duplicates(onnx_op_extractors)))

    class_registration.apply_replacements(
        graph, class_registration.ClassType.FRONT_REPLACER)

    create_tensor_nodes(graph)
    graph_clean_up(graph)

    override_placeholder_shapes(graph, packed_user_shapes)

    graph_clean_up(graph)
    remove_op_nodes(graph, {'op': 'Identity'})

    graph_clean_up(graph)

    remove_output_ops(graph)

    partial_infer(graph)
    graph_clean_up(graph)
    check_empty_graph(graph, 'partial_infer')

    add_input_ops(graph, packed_user_shapes, False)
    graph_clean_up(graph)
    check_empty_graph(graph, 'add_input_ops')

    scale_input(graph, scale)
    add_mean_scale_values(graph, mean_scale_values)

    convert_dilated_convolution(graph)
    graph_clean_up(graph)

    # Second consecutive clean-up kept exactly as in the original sequence.
    graph_clean_up(graph)

    remove_op_nodes(graph, {'op': 'Identity'})
    remove_useless_split(graph)

    class_registration.apply_replacements(
        graph, class_registration.ClassType.MIDDLE_REPLACER)

    convert_gemm_to_fully_connected(graph)
    NormalizeFullyConnected().find_and_replace_pattern(graph)

    fuse_pad(graph)
    graph_clean_up(graph)

    convert_batch_norm(graph)
    graph_clean_up(graph)

    convert_scale_shift_to_mul_add(graph)
    graph_clean_up(graph)

    fuse_mul_add_sequence(graph)
    graph_clean_up(graph)

    fuse_linear_ops(graph)
    graph_clean_up(graph)

    grouped_convolutions_fusing(graph)
    graph_clean_up(graph)

    fuse_linear_ops(graph)
    graph_clean_up(graph)

    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up(graph)

    convert_mul_add_to_power(graph)
    graph_clean_up(graph)

    convert_reshape(graph)
    convert_add_to_scaleshift(graph)  # scale = 1
    convert_mul_to_scaleshift(graph)  # biases = 0

    fuse_pad(graph)
    graph_clean_up(graph)

    fuse_sequence_of_reshapes(graph)
    graph_clean_up(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    merge_nodes_permutations(graph)
    permute_data_nodes_attrs(graph)
    permute_op_nodes_attrs(graph)

    class_registration.apply_replacements(
        graph, class_registration.ClassType.BACK_REPLACER)

    weights, xml_string = prepare_emit_ir(graph=graph,
                                          data_type=precision,
                                          output_dir=output_dir,
                                          output_model_name=output_model_name,
                                          meta_info={'unset': []})

    return weights, xml_string
Exemplo n.º 2
0
def driver_R1(onnx_modelproto_bytes,
              precision: str,
              output_model_name: str,
              outputs: list,
              output_dir: str,
              scale: float,
              user_shapes: [None, list, np.array] = None,
              mean_scale_values: [dict, list] = ()):
    """Convert a serialized ONNX model to IR with the ``ir_version`` 5 pass sequence.

    The model is parsed from bytes, converted to a graph with ``protobuf2nx``,
    given a fixed ``cmd_params`` namespace, and then run through the
    front/middle/back replacers and fusing passes before ``prepare_emit_ir``
    is called.

    Args:
        onnx_modelproto_bytes: Serialized ONNX model; wrapped in ``bytes()``
            and parsed with ``onnx.load_from_string``.
        precision: Forwarded to ``prepare_emit_ir`` as ``data_type``.
        output_model_name: Used as the graph name (the name stored in the
            ONNX graph is used when this is empty) and forwarded to
            ``prepare_emit_ir``.
        outputs: Accepted but not read by this function body.
        output_dir: Forwarded to ``prepare_emit_ir``.
        scale: Accepted but not read by this function body.
        user_shapes: Accepted but not read by this function body.
        mean_scale_values: Accepted but not read by this function body.

    Returns:
        The ``(weights, xml_string)`` pair returned by ``prepare_emit_ir``.

    Raises:
        Error: If the bytes cannot be parsed as an ONNX model, or if the
            graph cannot be built and pre-processed.
    """
    try:
        model_proto = onnx.load_from_string(bytes(onnx_modelproto_bytes))
    except Exception as e:
        print("[python] onnx exception: ", str(e))
        # Without a parsed model nothing below can work; fail here with the
        # real cause instead of an UnboundLocalError on ``model_proto``.
        raise Error(
            'Cannot load ONNX model from the provided serialized bytes. '
            'Details: {}.', str(e)) from e

    update_extractors_with_extensions(onnx_op_extractors)

    try:
        graph = protobuf2nx(model_proto)
        log.debug("Number of nodes in NX graph: {}".format(
            graph.number_of_nodes()))
        graph.__setattr__(
            'name',
            output_model_name if output_model_name else model_proto.graph.name)  # pylint: disable=no-member
        graph.graph['layout'] = 'NCHW'
        # There is no command line in this entry point, so a namespace with
        # fixed values is stored under 'cmd_params'.
        graph.graph['cmd_params'] = argparse.Namespace(
            batch=None,
            data_type='float',
            disable_fusing=False,
            disable_gfusing=False,
            disable_resnet_optimization=False,
            enable_concat_optimization=False,
            extensions=mo_extensions,
            finegrain_fusing=None,
            framework='onnx',
            freeze_placeholder_with_value=None,
            generate_deprecated_IR_V2=False,
            input=None,
            input_model=None,
            input_shape=None,
            keep_shape_ops=False,
            log_level='ERROR',
            mean_scale_values={},
            mean_values=(),
            model_name=None,
            move_to_preprocess=False,
            output=None,
            output_dir='.',
            placeholder_shapes=None,
            reverse_input_channels=False,
            scale=None,
            scale_values=(),
            silent=False,
            version=False)
        graph.graph['fw'] = 'onnx'
        graph.graph[
            'feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
        graph.graph['ir_version'] = 5
    except Exception as e:
        # The model comes from memory, so there is no file name to report;
        # a fixed placeholder fills the first slot of the message.
        raise Error(
            'Cannot pre-process ONNX graph after reading from model file "{}". '
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44), '<in-memory ONNX model>', str(e)) from e
    graph.check_empty_graph(
        'protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(
        graph, lambda node: onnx_op_extractor(
            node, check_for_duplicates(onnx_op_extractors)))

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(
        graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(
        graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)
    graph_clean_up_onnx(graph)

    # NOTE(review): the second argument is the string 'False', not the
    # boolean False - confirm this is what mark_unfused_nodes expects.
    mark_unfused_nodes(graph, 'False')
    convert_batch_norm(graph)
    graph_clean_up_onnx(graph)

    convert_scale_shift_to_mul_add(graph)
    graph_clean_up_onnx(graph)

    fuse_mul_add_sequence(graph)
    graph_clean_up_onnx(graph)

    fuse_linear_ops(graph)
    graph_clean_up_onnx(graph)

    grouped_convolutions_fusing(graph)
    graph_clean_up_onnx(graph)

    fuse_linear_ops(graph)
    graph_clean_up_onnx(graph)

    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up_onnx(graph)

    convert_mul_add_to_power(graph)
    graph_clean_up_onnx(graph)

    convert_reshape(graph)
    graph_clean_up_onnx(graph)
    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up_onnx(graph)

    fuse_pad(graph)
    graph_clean_up_onnx(graph)

    fuse_sequence_of_reshapes(graph)
    graph_clean_up_onnx(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    merge_nodes_permutations(graph)
    permute_data_nodes_attrs(graph)
    permute_op_nodes_attrs(graph)

    class_registration.apply_replacements(
        graph, class_registration.ClassType.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    weights, xml_string = prepare_emit_ir(graph=graph,
                                          data_type=precision,
                                          output_dir=output_dir,
                                          output_model_name=output_model_name,
                                          meta_info={'unset': []})

    return weights, xml_string
Exemplo n.º 3
0
def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: str, output_dir: str):
    """Load an ONNX model file, run the conversion passes and emit the IR.

    Args:
        argv: Parsed command-line options; stored on the graph as
            ``cmd_params`` and consulted for the fusing, preprocessing and
            IR-version switches.
        model_file_name: Path handed to ``load_onnx_model``.
        output_model_name: Graph name (the name stored in the ONNX graph is
            used when this is empty); also forwarded to ``prepare_emit_ir``.
        output_dir: Forwarded to ``prepare_emit_ir``.

    Returns:
        Always ``0``.

    Raises:
        Error: If building or annotating the graph from the loaded model fails.
    """
    def run_with_cleanup(*graph_passes):
        # Apply every pass in the given order, cleaning the graph after each.
        for graph_pass in graph_passes:
            graph_pass(graph)
            graph_clean_up_onnx(graph)

    def extract_attrs(node):
        # check_for_duplicates is invoked for each node, not once up front.
        return onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors))

    replacer_types = class_registration.ClassType

    meta_info = get_meta_info(argv)

    model_proto = load_onnx_model(model_file_name)
    model_graph = model_proto.graph  # pylint: disable=no-member
    node_total = len(model_graph.node)
    input_total = len(model_graph.input)
    initializer_total = len(model_graph.initializer)
    log.debug("Number of nodes in graph_def: {}".format(node_total))
    log.debug("Number of all input ports (not true inputs) in graph_def: {}".format(input_total))
    log.debug("Number of initializers in graph_def: {}".format(initializer_total))
    log.debug("Number of real inputs in graph_def: {}".format(input_total - initializer_total))
    update_extractors_with_extensions(onnx_op_extractors)

    try:
        graph = protobuf2nx(model_proto)
        log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))
        graph.name = output_model_name or model_proto.graph.name  # pylint: disable=no-member
        graph.graph['layout'] = 'NCHW'
        graph.graph['cmd_params'] = argv
        graph.graph['fw'] = 'onnx'
        graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
        graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
    except Exception as e:
        message = ('Cannot pre-process ONNX graph after reading from model file "{}". '
                   'File is corrupt or has unsupported format. Details: {}. ') + refer_to_faq_msg(44)
        raise Error(message, model_file_name, str(e)) from e
    graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(graph, extract_attrs)

    # ------------------------------------ end of loading ------------------------------------
    for replacer_type in (replacer_types.FRONT_REPLACER, replacer_types.MIDDLE_REPLACER):
        class_registration.apply_replacements(graph, replacer_type)

    run_with_cleanup(fuse_pad)

    # Nodes excluded from fusing get 'can_be_fused': False
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # FusedBatchNorm becomes a Mul->Add->Mul->Add chain: IE has no BN with
    # 4 inputs, so it is split into two ScaleShift
    run_with_cleanup(convert_batch_norm)

    if not argv.disable_fusing:
        # ScaleShift -> Mul->Add, then fuse Mul/Add runs, then fuse linear
        # operations into Convolution
        run_with_cleanup(convert_scale_shift_to_mul_add,
                         fuse_mul_add_sequence,
                         fuse_linear_ops)

    if not argv.disable_gfusing:
        run_with_cleanup(grouped_convolutions_fusing)
        if not argv.disable_fusing:
            run_with_cleanup(fuse_linear_ops)

    for quantize_fuse in (AddQuantizeFuse, MulQuantizeFuse):
        quantize_fuse().find_and_replace_pattern(graph)

    run_with_cleanup(convert_muladd_to_scaleshift_or_power,
                     convert_mul_add_to_power,
                     convert_reshape,
                     convert_add_or_mul_to_scaleshift,  # scale = 1
                     fuse_pad)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        run_with_cleanup(move_scaleshift_to_preprocess)

    run_with_cleanup(fuse_sequence_of_reshapes)

    EltwiseInputNormalize().find_and_replace_pattern(graph)

    for permutation_step in (merge_nodes_permutations,
                             permute_data_nodes_attrs,
                             permute_op_nodes_attrs):
        permutation_step(graph)

    class_registration.apply_replacements(graph, replacer_types.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    prepare_emit_ir(graph=graph,
                    data_type=argv.data_type,
                    output_dir=output_dir,
                    output_model_name=output_model_name,
                    meta_info=meta_info)

    return 0
Exemplo n.º 4
0
def driver(onnx_modelproto_bytes, precision: str, output_model_name: str,
           output_dir: str):
    """Convert a serialized ONNX model to IR with the ``ir_version`` 6 pass sequence.

    The model is parsed from bytes, converted to a graph with ``protobuf2nx``,
    given a fixed ``cmd_params`` namespace, and then run through the
    front/middle/back replacers and fusing passes before ``prepare_emit_ir``
    is called.

    Args:
        onnx_modelproto_bytes: Serialized ONNX model; wrapped in ``bytes()``
            and parsed with ``onnx.load_from_string``.
        precision: Forwarded to ``prepare_emit_ir`` as ``data_type``.
        output_model_name: Used as the graph name (the name stored in the
            ONNX graph is used when this is empty) and forwarded to
            ``prepare_emit_ir``.
        output_dir: Forwarded to ``prepare_emit_ir``.

    Returns:
        The ``(weights, xml_string)`` pair returned by ``prepare_emit_ir``.

    Raises:
        Error: If the bytes cannot be parsed as an ONNX model, or if the
            graph cannot be built and pre-processed.
    """
    try:
        model_proto = onnx.load_from_string(bytes(onnx_modelproto_bytes))
    except Exception as e:
        print("[python] onnx exception: ", str(e))
        # Without a parsed model nothing below can work; fail here with the
        # real cause instead of an UnboundLocalError on ``model_proto``.
        raise Error(
            'Cannot load ONNX model from the provided serialized bytes. '
            'Details: {}.', str(e)) from e

    model_graph = model_proto.graph  # pylint: disable=no-member
    log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
    log.debug(
        "Number of all input ports (not true inputs) in graph_def: {}".format(
            len(model_graph.input)))
    log.debug("Number of initializers in graph_def: {}".format(
        len(model_graph.initializer)))
    log.debug("Number of real inputs in graph_def: {}".format(
        len(model_graph.input) - len(model_graph.initializer)))
    update_extractors_with_extensions(onnx_op_extractors)

    try:
        graph = protobuf2nx(model_proto)
        log.debug("Number of nodes in NX graph: {}".format(
            graph.number_of_nodes()))
        graph.__setattr__(
            'name',
            output_model_name if output_model_name else model_proto.graph.name)  # pylint: disable=no-member
        graph.graph['layout'] = 'NCHW'
        graph.graph['fw'] = 'onnx'
        graph.graph[
            'feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
        # There is no command line in this entry point, so a namespace with
        # fixed values is stored under 'cmd_params'.
        graph.graph['cmd_params'] = argparse.Namespace(
            batch=None,
            data_type='float',
            disable_fusing=False,
            disable_gfusing=False,
            disable_resnet_optimization=False,
            enable_concat_optimization=False,
            extensions=mo_extensions,
            finegrain_fusing=None,
            framework='onnx',
            freeze_placeholder_with_value=None,
            generate_deprecated_IR_V2=False,
            input=None,
            input_model=None,
            input_shape=None,
            keep_shape_ops=False,
            log_level='ERROR',
            mean_scale_values={},
            mean_values=(),
            model_name=None,
            move_to_preprocess=False,
            output=None,
            output_dir='.',
            placeholder_shapes=None,
            reverse_input_channels=False,
            scale=None,
            scale_values=(),
            silent=False,
            version=False,
            blobs_as_inputs=False,
            keep_quantize_ops_in_IR=False,
            generate_experimental_IR_V10=False)
        graph.graph['ir_version'] = 6

    except Exception as e:
        # The model comes from memory, so there is no file name to report;
        # a fixed placeholder fills the first slot of the message.
        raise Error(
            'Cannot pre-process ONNX graph after reading from model file "{}". ' \
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            '<in-memory ONNX model>',
            str(e)
        ) from e
    graph.check_empty_graph(
        'protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(
        graph, lambda node: onnx_op_extractor(
            node, check_for_duplicates(onnx_op_extractors)))

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(
        graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(
        graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)
    graph_clean_up_onnx(graph)

    for_graph_and_each_sub_graph_recursively(
        graph, convert_matmul_to_fully_connected)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, False)

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    # IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
    convert_batch_norm(graph)
    graph_clean_up_onnx(graph)

    # Converting ScaleShift layer to Mul->Add
    convert_scale_shift_to_mul_add(graph)
    graph_clean_up_onnx(graph)

    # Fusing the sequences of Mul/Add operations
    fuse_mul_add_sequence(graph)
    graph_clean_up_onnx(graph)

    # Fusing linear operation to Convolution
    fuse_linear_ops(graph)
    graph_clean_up_onnx(graph)

    grouped_convolutions_fusing(graph)
    graph_clean_up_onnx(graph)

    fuse_linear_ops(graph)
    graph_clean_up_onnx(graph)

    MarkNodesToFuseUpToFakeQuantize().find_and_replace_pattern(graph)
    FakeQuantizeFuse().find_and_replace_pattern(graph)

    AddFakeQuantizeFuse().find_and_replace_pattern(graph)
    MulFakeQuantizeFuse().find_and_replace_pattern(graph)

    convert_muladd_to_scaleshift(graph)
    graph_clean_up_onnx(graph)

    # Second consecutive clean-up kept exactly as in the original sequence.
    graph_clean_up_onnx(graph)
    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up_onnx(graph)

    fuse_pad(graph)
    graph_clean_up_onnx(graph)

    FuseReshapesSequence().find_and_replace_pattern(graph)
    RemoveRedundantReshapes().find_and_replace_pattern(graph)

    graph_clean_up_onnx(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    merge_nodes_permutations(graph)
    permute_data_nodes_attrs(graph)
    permute_op_nodes_attrs(graph)

    graph_clean_up_onnx(graph)
    class_registration.apply_replacements(
        graph, class_registration.ClassType.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    weights, xml_string = prepare_emit_ir(graph=graph,
                                          data_type=precision,
                                          output_dir=output_dir,
                                          output_model_name=output_model_name,
                                          meta_info={'unset': []})

    return weights, xml_string
Exemplo n.º 5
0
def driver(argv: argparse.Namespace,
           proto_file_name: str,
           model_file_name: str,
           output_model_name: str,
           output_dir: str,
           caffe_proto_path: str,
           mean_file: str = "",
           mean_file_offsets: tuple = None,
           custom_layers_mapping_path: str = None):
    """Load a Caffe model, run the conversion passes and emit the IR.

    Args:
        argv: Parsed command-line options; stored on the graph as
            ``cmd_params`` and consulted for the fusing, preprocessing and
            IR-version switches.
        proto_file_name: Passed to ``loader.load_caffe_proto_model`` and
            stored on the graph as ``proto_path``.
        model_file_name: Passed to ``loader.load_caffe_proto_model`` and
            stored on the graph as ``caffemodel_path``.
        output_model_name: Graph name when the proto has no truthy ``name``;
            also forwarded to ``prepare_emit_ir``.
        output_dir: Forwarded to ``prepare_emit_ir``.
        caffe_proto_path: Passed to ``loader.import_caffe_pb2``.
        mean_file: Optional mean file; parsed only when exactly one original
            shape was reported by ``loader.caffe_pb_to_nx``.
        mean_file_offsets: Forwarded to ``loader.parse_mean``.
        custom_layers_mapping_path: Passed to
            ``custom_layers_mapping.load_layers_xml``.

    Returns:
        Always ``0``.

    Raises:
        Error: On a ``ValueError`` while building the graph or processing the
            mean file, or when a mean file is given and the number of
            original shapes is not one.
    """
    def run_with_cleanup(*graph_passes):
        # Apply every pass in the given order, cleaning the graph after each.
        for graph_pass in graph_passes:
            graph_pass(graph)
            graph_clean_up(graph)

    def extract_attrs(node):
        # check_for_duplicates is invoked for each node, not once up front.
        return caffe_extractor(node,
                               check_for_duplicates(caffe_type_extractors))

    replacer_types = class_registration.ClassType

    log_step(argv.steps, 'LOAD')
    meta_info = get_meta_info(argv)

    caffe_pb2 = loader.import_caffe_pb2(caffe_proto_path)
    proto, model = loader.load_caffe_proto_model(caffe_pb2, proto_file_name,
                                                 model_file_name)

    # Options that argv does not carry are treated as False.
    update_extractors_with_extensions(
        caffe_type_extractors,
        getattr(argv, 'disable_omitting_optional', False),
        getattr(argv, 'disable_flattening_optional_params', False))

    try:
        graph, original_shapes = loader.caffe_pb_to_nx(proto, model)
    except ValueError as e:
        message = 'Invalid prototxt file: value error {}. ' + refer_to_faq_msg(11)
        raise Error(message, str(e)) from e

    log.debug("After caffe_pb_to_nx")
    graph.print_graph_stat()
    graph.check_empty_graph('load_caffe_proto_model')

    graph.proto_path = proto_file_name
    graph.caffemodel_path = model_file_name
    graph.name = getattr(proto, 'name', None) or output_model_name
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'caffe'
    version = 10 if graph.graph['cmd_params'].generate_experimental_IR_V10 else 6
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version

    custom_layers_map = custom_layers_mapping.load_layers_xml(
        custom_layers_mapping_path)
    custom_layers_mapping.update_extractors(
        caffe_type_extractors,
        custom_layers_map,
        getattr(argv, 'disable_omitting_optional', False),
        getattr(argv, 'enable_flattening_nested_params', False))
    extract_node_attrs(graph, extract_attrs)

    # ------------------------------------ end of loading ------------------------------------
    log_step(argv.steps, 'FRONT')
    class_registration.apply_replacements(graph,
                                          replacer_types.FRONT_REPLACER)
    log_step(argv.steps, 'MIDDLE')
    class_registration.apply_replacements(graph,
                                          replacer_types.MIDDLE_REPLACER)

    # Nodes excluded from fusing get 'can_be_fused': False
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Runs even with fusing disabled: it converts scale with 2 inputs
    run_with_cleanup(convert_scale_shift_to_mul_add)

    if not argv.disable_fusing:
        run_with_cleanup(convert_bn_to_mul_add,
                         fuse_mul_add_sequence,
                         fuse_linear_ops)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    for conversion in (convert_muladd_to_scaleshift,
                       convert_matmul_to_fully_connected,
                       batch_norm_fuse,
                       convert_add_or_mul_to_scaleshift):  # last one: scale = 1
        conversion(graph)
    graph_clean_up(graph)

    log.debug("After graph_cleanup")
    graph.print_graph_stat()

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        run_with_cleanup(move_scaleshift_to_preprocess)

    for reshape_cleanup in (FuseReshapesSequence, RemoveRedundantReshapes):
        reshape_cleanup().find_and_replace_pattern(graph)

    input_names = find_inputs(graph)
    mf = []
    try:
        if mean_file:
            if len(original_shapes) != 1:
                raise Error(
                    'Mean file for topologies with multiple inputs is not supported. '
                    + refer_to_faq_msg(9))
            mf = loader.parse_mean(mean_file, original_shapes[input_names[0]],
                                   mean_file_offsets, caffe_pb2)
    except ValueError as e:
        message = ('Cannot load or process mean file: value error {}. ' +
                   refer_to_faq_msg(10))
        raise Error(message, str(e)) from e

    for permutation_step in (merge_nodes_permutations,
                             permute_data_nodes_attrs,
                             permute_op_nodes_attrs):
        permutation_step(graph)

    graph_clean_up(graph)
    log_step(argv.steps, 'BACK')
    class_registration.apply_replacements(graph,
                                          replacer_types.BACK_REPLACER)

    remove_const_ops(graph)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    remove_output_ops(graph)
    log_step(argv.steps, 'EMIT')
    prepare_emit_ir(graph=graph,
                    data_type=argv.data_type,
                    output_dir=output_dir,
                    output_model_name=output_model_name,
                    mean_data=mf,
                    input_names=input_names,
                    meta_info=meta_info)
    return 0
Exemplo n.º 6
0
def driver(argv: argparse.Namespace, input_model: str, output_model_name: str,
           output_dir: str):
    """Load an MXNet model, run the conversion passes and emit the IR.

    Args:
        argv: Parsed command-line options; stored on the graph as
            ``cmd_params`` and consulted for the loading, fusing,
            preprocessing and IR-version switches.
        input_model: Model reference handed to ``load_symbol_def``.
        output_model_name: Graph name; also forwarded to ``prepare_emit_ir``.
        output_dir: Forwarded to ``prepare_emit_ir``.

    Returns:
        Always ``0``.

    Raises:
        FrameworkError: If ``load_symbol_def`` raises ``ValueError`` or
            ``mxnet.base.MXNetError``.
    """
    def run_with_cleanup(*graph_passes):
        # Apply every pass in the given order, cleaning the graph after each.
        for graph_pass in graph_passes:
            graph_pass(graph)
            graph_clean_up(graph)

    replacer_types = class_registration.ClassType

    log_step(argv.steps, 'LOAD')
    meta_info = get_meta_info(argv)

    try:
        loaded = load_symbol_def(
            input_model, argv.input_symbol, argv.input, argv.nd_prefix_name,
            argv.pretrained_model_name, argv.legacy_mxnet_model)
        model_nodes, model_params, model_name, iteration_number = loaded
    except (ValueError, mxnet.base.MXNetError) as e:
        message = ('The following error happened while loading mxnet model {}: {}. ' +
                   refer_to_faq_msg(53))
        raise FrameworkError(message, input_model, str(e)) from e

    # Parameters are saved only when all three options are set.
    if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
        save_params_file(model_name, model_params._arg_params,
                         model_params._aux_params, iteration_number)

    update_extractors_with_extensions(mxnet_op_extractors)
    graph = symbol2nx(model_nodes, model_params, argv.input)
    graph.check_empty_graph(
        'symbol2nx. It may happen due to problems with loaded model')

    graph.name = output_model_name
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'mxnet'
    graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3

    version = 10 if graph.graph['cmd_params'].generate_experimental_IR_V10 else 6
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version

    extract_node_attrs(graph, mxnet_op_extractor)

    # ------------------------------------ end of loading ------------------------------------
    log_step(argv.steps, 'FRONT')
    class_registration.apply_replacements(graph,
                                          replacer_types.FRONT_REPLACER)
    log_step(argv.steps, 'MIDDLE')
    class_registration.apply_replacements(graph,
                                          replacer_types.MIDDLE_REPLACER)

    fuse_pad(graph)

    # Nodes excluded from fusing get 'can_be_fused': False
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # FusedBatchNorm becomes a Mul->Add->Mul->Add chain
    run_with_cleanup(convert_batch_norm)

    if not argv.disable_fusing:
        # ScaleShift -> Mul->Add, then fuse Mul/Add runs, then fuse linear
        # operations into Convolution
        run_with_cleanup(convert_scale_shift_to_mul_add,
                         fuse_mul_add_sequence,
                         fuse_linear_ops)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    fuse_pad(graph)

    # Mul->Add pairs back to ScaleShift, then lone Add/Mul (scale = 1)
    run_with_cleanup(convert_muladd_to_scaleshift,
                     convert_add_or_mul_to_scaleshift)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        run_with_cleanup(move_scaleshift_to_preprocess)

    EltwiseInputNormalize().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(
        graph, convert_matmul_to_fully_connected)

    for permutation_step in (merge_nodes_permutations,
                             permute_data_nodes_attrs,
                             permute_op_nodes_attrs):
        permutation_step(graph)

    graph_clean_up(graph)
    log_step(argv.steps, 'BACK')
    class_registration.apply_replacements(graph,
                                          replacer_types.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    log_step(argv.steps, 'EMIT')
    prepare_emit_ir(graph=graph,
                    data_type=argv.data_type,
                    output_dir=output_dir,
                    output_model_name=output_model_name,
                    meta_info=meta_info)
    return 0
Exemplo n.º 7
0
def driver(argv: argparse.Namespace,
           model_file_name: str,
           output_model_name: str,
           outputs: list,
           output_dir: str,
           scale: float,
           user_shapes: [None, list, np.array] = None,
           mean_scale_values: [dict, list] = ()):
    """Load an ONNX model file, run the conversion passes and emit the IR.

    Args:
        argv: Parsed command-line options; stored on the graph as
            ``cmd_params`` and consulted for the batch, fusing,
            preprocessing and IR-version switches.
        model_file_name: Path handed to ``load_onnx_model``.
        output_model_name: Graph name (the name stored in the ONNX graph is
            used when this is empty); also forwarded to ``prepare_emit_ir``.
        outputs: Requested outputs, handed to ``user_data_repack``.
        output_dir: Forwarded to ``prepare_emit_ir``.
        scale: Passed to ``scale_input``.
        user_shapes: User-supplied shapes, handed to ``user_data_repack``.
        mean_scale_values: Passed to ``add_mean_scale_values``.

    Returns:
        Always ``0``.

    Raises:
        Error: If building or annotating the graph from the loaded model fails.
    """
    def run_with_cleanup(*graph_passes):
        # Apply every pass in the given order, cleaning the graph after each.
        for graph_pass in graph_passes:
            graph_pass(graph)
            graph_clean_up(graph)

    def basic_attrs(node):
        return True, common_onnx_fields(node)

    def full_attrs(node):
        # check_for_duplicates is invoked for each node, not once up front.
        return onnx_op_extractor(node,
                                 check_for_duplicates(onnx_op_extractors))

    replacer_types = class_registration.ClassType

    meta_info = get_meta_info(argv)

    model_proto = load_onnx_model(model_file_name)
    model_graph = model_proto.graph  # pylint: disable=no-member
    node_total = len(model_graph.node)
    input_total = len(model_graph.input)
    initializer_total = len(model_graph.initializer)
    log.debug("Number of nodes in graph_def: {}".format(node_total))
    log.debug(
        "Number of all input ports (not true inputs) in graph_def: {}".format(
            input_total))
    log.debug("Number of initializers in graph_def: {}".format(
        initializer_total))
    log.debug("Number of real inputs in graph_def: {}".format(
        input_total - initializer_total))
    update_extractors_with_extensions(onnx_op_extractors)

    try:
        graph = protobuf2nx(model_proto)
        log.debug("Number of nodes in NX graph: {}".format(
            graph.number_of_nodes()))
        graph.name = output_model_name or model_proto.graph.name  # pylint: disable=no-member
        graph.graph['layout'] = 'NCHW'
        graph.graph['cmd_params'] = argv
        graph.graph['fw'] = 'onnx'
        graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
        graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 4
        # Basic attributes are extracted early so that passes relying on
        # them can run before the full attribute extractor is called
        extract_node_attrs(graph, basic_attrs)
    except Exception as e:
        message = ('Cannot pre-process ONNX graph after reading from model file "{}". '
                   'File is corrupt or has unsupported format. Details: {}. ') + refer_to_faq_msg(44)
        raise Error(message, model_file_name, str(e)) from e
    check_empty_graph(
        graph, 'protobuf2nx. It may happen due to problems with loaded model')
    packed_user_shapes, packed_outputs, _ = user_data_repack(
        graph, user_shapes, outputs, None)

    add_output_ops(graph, packed_outputs)
    add_input_ops(graph, packed_user_shapes, True)

    # This clean-up removes child nodes of outputs, which is useful when a
    # custom output is specified
    graph_clean_up(graph)
    check_empty_graph(graph, 'add_output_ops and add_input_ops')
    extract_node_attrs(graph, full_attrs)

    class_registration.apply_replacements(graph,
                                          replacer_types.FRONT_REPLACER)

    run_with_cleanup(create_tensor_nodes)

    override_placeholder_shapes(graph, packed_user_shapes)
    override_batch(graph, argv.batch)

    graph_clean_up(graph)
    remove_op_nodes(graph, {'op': 'Identity'})

    graph_clean_up(graph)

    remove_output_ops(graph)

    run_with_cleanup(partial_infer)
    check_empty_graph(graph, 'partial_infer')

    add_input_ops(graph, packed_user_shapes, False)
    graph_clean_up(graph)
    check_empty_graph(graph, 'add_input_ops')

    scale_input(graph, scale)
    add_mean_scale_values(graph, mean_scale_values)

    run_with_cleanup(convert_dilated_convolution)

    graph_clean_up(graph)

    remove_op_nodes(graph, {'op': 'Identity'})
    remove_useless_split(graph)

    class_registration.apply_replacements(graph,
                                          replacer_types.MIDDLE_REPLACER)

    convert_gemm_to_fully_connected(graph)
    NormalizeFullyConnected().find_and_replace_pattern(graph)

    run_with_cleanup(fuse_pad)

    # Nodes excluded from fusing get 'can_be_fused': False
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # FusedBatchNorm becomes a Mul->Add->Mul->Add chain: IE has no BN with
    # 4 inputs, so it is split into two ScaleShift
    run_with_cleanup(convert_batch_norm)

    if not argv.disable_fusing:
        # ScaleShift -> Mul->Add, then fuse Mul/Add runs, then fuse linear
        # operations into Convolution
        run_with_cleanup(convert_scale_shift_to_mul_add,
                         fuse_mul_add_sequence,
                         fuse_linear_ops)

    if not argv.disable_gfusing:
        run_with_cleanup(grouped_convolutions_fusing)
        if not argv.disable_fusing:
            run_with_cleanup(fuse_linear_ops)

    run_with_cleanup(convert_muladd_to_scaleshift_or_power,
                     convert_mul_add_to_power)

    for conversion in (convert_reshape,
                       convert_add_to_scaleshift,   # scale = 1
                       convert_mul_to_scaleshift):  # biases = 0
        conversion(graph)

    run_with_cleanup(fuse_pad)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        run_with_cleanup(move_scaleshift_to_preprocess)

    run_with_cleanup(fuse_sequence_of_reshapes)

    EltwiseInputNormalize().find_and_replace_pattern(graph)

    for permutation_step in (merge_nodes_permutations,
                             permute_data_nodes_attrs,
                             permute_op_nodes_attrs):
        permutation_step(graph)

    class_registration.apply_replacements(graph,
                                          replacer_types.BACK_REPLACER)

    prepare_emit_ir(graph=graph,
                    data_type=argv.data_type,
                    output_dir=output_dir,
                    output_model_name=output_model_name,
                    meta_info=meta_info)

    return 0