Example #1
def collect_extenders(path: str):
    """
    A function to register all MO IR Reader extenders
    :param path: Path to Model Optimizer folder
    :return:
    """
    import_by_path(os.path.join(path, 'mo', 'utils', 'ir_reader', 'extenders'),
                   ['mo', 'utils', 'ir_reader', 'extenders'])
    update_registration(classes=[Extender], enabled_transforms=[], disabled_transforms=[])
Example #2
def collect_ops(path: str):
    """
    A function to register all MO ops
    :param path: Path to Model Optimizer folder
    :return:
    """
    import_by_path(os.path.join(path, 'mo', 'ops'), ['mo', 'ops'])
    import_by_path(os.path.join(path, 'extensions', 'ops'), ['extensions', 'ops'])
    update_registration(classes=[Op, Activation, Elementwise, LogicalElementwise, ReduceOp, Scatter],
                        enabled_transforms=[], disabled_transforms=[])
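Both collectors take the Model Optimizer root folder as their only argument. A minimal usage sketch, where mo_root is a hypothetical checkout path:

import os

mo_root = os.path.expanduser('~/openvino/model-optimizer')  # hypothetical path
collect_ops(mo_root)        # registers Op subclasses found under mo/ops and extensions/ops
collect_extenders(mo_root)  # registers Extender subclasses for the IR reader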
Example #3
def load_dir(framework: str, path: str, get_front_classes: callable):
    """
    Assuming the following sub-directory structure for path:

        front/
            <framework>/
                <other_files>.py
            <other_directories>/
            <other_files>.py
        ops/
            <ops_files>.py
        middle/
            <other_files>.py
        back/
            <other_files>.py

    This function loads modules in the following order:
        1. ops/<ops_files>.py
        2. front/<other_files>.py
        3. front/<framework>/<other_files>.py
        4. middle/<other_files>.py
        5. back/<other_files>.py

    Handlers loaded later override handlers registered earlier for the same op.
    Steps 1, 2 and 3 may register handlers for the same op, while step 4
    registers a transformation pass and should not conflict with anything
    loaded by steps 1, 2 or 3.
    Files from front/<other_directories>/ are not loaded.
    """
    log.info("Importing extensions from: {}".format(path))
    root_dir, ext = os.path.split(path)
    sys.path.insert(0, root_dir)

    enabled_transforms, disabled_transforms = get_enabled_and_disabled_transforms()

    front_classes = get_front_classes()
    internal_dirs = {
        ('ops', ): [Op],
        ('analysis', ): [AnalyzeAction],
        ('load', framework): [Loader],
        ('front', ): front_classes,
        ('front', framework): front_classes,
        ('middle', ): [MiddleReplacementPattern],
        ('back', ): [BackReplacementPattern]
    }

    if ext == 'mo':
        internal_dirs[('front', framework, 'extractors')] = front_classes

    for p, classes in internal_dirs.items():
        import_by_path(os.path.join(path, *p), [ext, *p])
        update_registration(classes, enabled_transforms, disabled_transforms)
    sys.path.remove(root_dir)
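A usage sketch for load_dir; the framework name, the path and the returned class list are illustrative assumptions, not the exact Model Optimizer invocation:

import os

def get_front_classes():
    # illustrative: the real callable returns the front replacement base classes
    return [FrontReplacementOp, FrontReplacementSubgraph]

load_dir('tf', os.path.join('/opt/model-optimizer', 'extensions'), get_front_classes)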
Example #4
    def find_and_replace_pattern(self, graph: Graph):
        argv = graph.graph['cmd_params']
        registry = CustomReplacementRegistry()
        registry.add_custom_replacement_description_from_config(
            argv.tensorflow_use_custom_operations_config)

        # automatically generate sub-classes for custom replacements that replace a sub-graph with a single node
        for replacement_desc in registry.get_all_replacements_descriptions():
            if replacement_desc.has('op'):
                type('FrontReplacementFromConfigFileOp' + replacement_desc.op,
                     (FrontReplacementFromConfigFileOp, ),
                     {'replacement_id': replacement_desc.id})
        update_registration([FrontReplacementFromConfigFileOp],
                            *get_enabled_and_disabled_transforms())
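The loop above generates one sub-class per replacement description with the three-argument type() built-in, so update_registration can later discover them. A self-contained sketch of just that technique, with a hypothetical base class and ids:

class Base:
    replacement_id = None

for rid in ('ObjectDetectionAPI', 'RetinaNet'):  # hypothetical ids
    # type(name, bases, namespace) creates a new class at runtime
    cls = type('FromConfigFileOp' + rid, (Base,), {'replacement_id': rid})
    print(cls.__name__, cls.replacement_id)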
Example #5
def driver(argv: argparse.Namespace,
           proto_file_name: str,
           model_file_name: str,
           output_model_name: str,
           outputs: list,
           output_dir: str,
           scale: float,
           user_shapes: [None, list, np.array] = None,
           mean_scale_values: [dict, list] = (),
           mean_file: str = "",
           mean_file_offsets: tuple = None,
           custom_layers_mapping_path: str = None):
    meta_info = get_meta_info(argv)

    FusePermutesSequence.enabled = False

    proto, model = loader.load_caffe_proto_model(proto_file_name,
                                                 model_file_name)

    update_extractors_with_extensions(
        caffe_type_extractors, argv.disable_omitting_optional if hasattr(
            argv, 'disable_omitting_optional') else False,
        argv.disable_flattening_optional_params if hasattr(
            argv, 'disable_flattening_optional_params') else False)

    try:
        graph, original_shapes = loader.caffe_pb_to_nx(proto, model)
    except ValueError as e:
        raise Error(
            'Invalid prototxt file: value error {}. ' + refer_to_faq_msg(11),
            str(e)) from e

    log.debug("After caffe_pb_to_nx")
    print_graph_stat(graph)
    check_empty_graph(graph, 'load_caffe_proto_model')

    graph.__setattr__('proto_path', proto_file_name)
    graph.__setattr__('caffemodel_path', model_file_name)
    graph.__setattr__('name',
                      getattr(proto, 'name', None) or output_model_name)
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'caffe'
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 4

    extract_node_attrs(graph, lambda node: (True, common_caffe_fields(node)))

    log.debug("After adding specific nodes for outputs")
    print_graph_stat(graph)

    custom_layers_map = custom_layers_mapping.load_layers_xml(
        custom_layers_mapping_path)
    custom_layers_mapping.update_extractors(
        caffe_type_extractors,
        custom_layers_map, argv.disable_omitting_optional if hasattr(
            argv, 'disable_omitting_optional') else False,
        argv.enable_flattening_nested_params if hasattr(
            argv, 'enable_flattening_nested_params') else False)

    extract_node_attrs(
        graph, lambda node: caffe_extractor(
            node, check_for_duplicates(caffe_type_extractors)))

    log.debug("After extract_node_attr")
    print_graph_stat(graph)

    packed_user_shapes, packed_outputs, freeze_placeholder = user_data_repack(
        graph, user_shapes, outputs, argv.freeze_placeholder_with_value)
    if argv.freeze_placeholder_with_value is not None:
        FreezePlaceholderValue.enabled = True
        FreezePlaceholderValue.replacement_dict = freeze_placeholder
        class_registration.update_registration([FrontReplacementSubgraph])
    output_op_nodes = add_output_ops(graph, packed_outputs)
    input_op_nodes = add_input_ops(graph, packed_user_shapes, True)
    override_placeholder_shapes(graph, packed_user_shapes)
    override_batch(graph, argv.batch)
    graph_clean_up(graph)
    check_empty_graph(graph, 'add_output_ops and add_input_ops')
    class_registration.apply_replacements(
        graph, class_registration.ClassType.FRONT_REPLACER)

    graph = create_tensor_nodes(graph)

    log.debug("After create_tensor_nodes")
    print_graph_stat(graph)

    remove_op_nodes(graph, {'op': 'Identity'})
    remove_output_ops(graph)
    graph_clean_up(graph)

    log.debug("After removing specific nodes for output")
    print_graph_stat(graph)

    # required network outputs should be passed here, but there is no way to
    # specify them yet, so all discovered sinks are marked as outputs
    mark_outputs(graph)
    graph_clean_up(graph)
    log.debug("After graph_cleanup")
    print_graph_stat(graph)

    graph = partial_infer(graph)
    log.debug("After partial_infer")
    print_graph_stat(graph)
    check_empty_graph(graph, 'partial_infer')
    duplicate_shared_weights(graph)

    input_op_nodes = add_input_ops(graph, packed_user_shapes, False)
    graph_clean_up(graph)
    check_empty_graph(graph, 'add_input_ops')
    scale_input(graph, scale)

    add_mean_scale_values(graph, mean_scale_values)

    log.debug("Split multi input convolutions")
    convert_multi_input_conv(graph)

    graph_clean_up(graph)
    log.debug("After graph_cleanup")
    print_graph_stat(graph)

    remove_op_nodes(graph, {'op': 'Dropout'})
    remove_op_nodes(graph, {'phase': 0})
    graph_clean_up(graph)

    class_registration.apply_replacements(
        graph, class_registration.ClassType.MIDDLE_REPLACER)

    mean_to_avgpool(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # this pass is needed even without fusing to convert Scale layers with 2 inputs
    convert_scale_shift_to_mul_add(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        convert_bn_to_mul_add(graph)
        graph_clean_up(graph)

        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    convert_muladd_to_scaleshift_or_power(graph)
    convert_matmul_to_fully_connected(graph)
    batch_norm_fuse(graph)
    convert_mul_add_to_power(graph)
    convert_add_to_scaleshift(graph)  # scale = 1
    convert_mul_to_scaleshift(graph)  # biases = 0

    graph_clean_up(graph)
    log.debug("After graph_cleanup")
    print_graph_stat(graph)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    fuse_sequence_of_reshapes(graph)

    input_names = find_inputs(graph)
    mf = []
    try:
        if mean_file and len(original_shapes) == 1:
            mf = loader.parse_mean(mean_file, original_shapes[input_names[0]],
                                   mean_file_offsets)
        elif mean_file:
            raise Error(
                'Mean file for topologies with multiple inputs is not supported. '
                + refer_to_faq_msg(9))
    except ValueError as e:
        raise Error(
            'Cannot load or process mean file: value error {}. ' +
            refer_to_faq_msg(10), str(e)) from e

    class_registration.apply_replacements(
        graph, class_registration.ClassType.BACK_REPLACER)

    prepare_emit_ir(graph=graph,
                    data_type=argv.data_type,
                    output_dir=output_dir,
                    output_model_name=output_model_name,
                    mean_data=mf,
                    input_names=input_names,
                    meta_info=meta_info)
    return 0
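driver() repeatedly guards optional argparse flags with hasattr(argv, ...) before reading them. A standalone sketch of that pattern; the flag names here are illustrative:

import argparse

argv = argparse.Namespace(disable_fusing=False)

# equivalent to the hasattr()/else pairs in driver(), collapsed into one call
disable_omitting = getattr(argv, 'disable_omitting_optional', False)
print(disable_omitting)  # False: the flag is absent, so the default applies

getattr with a default value is the shorter equivalent of the hasattr checks used above.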
Example #6
def update_registration():
    class_registration.update_registration([Op, FrontExtractorOp, FrontReplacementOp, FrontReplacementSubgraph,
                                            MXNetCustomFrontExtractorOp, MiddleReplacementPattern,
                                            BackReplacementPattern, FrontReplacementPattern])
Example #7
def update_registration():
    class_registration.update_registration([
        Op, FrontExtractorOp, CaffePythonFrontExtractorOp, FrontReplacementOp,
        FrontReplacementPattern, FrontReplacementSubgraph,
        MiddleReplacementPattern, BackReplacementPattern
    ])
Example #8
def update_registration():
    class_registration.update_registration([Op, FrontExtractorOp, FrontReplacementOp, FrontReplacementPattern,
                                            FrontReplacementSubgraph, FrontReplacementFromConfigFileSubGraph,
                                            FrontReplacementFromConfigFileOp, MiddleReplacementPattern,
                                            BackReplacementPattern, FrontReplacementFromConfigFileGeneral])
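Examples #6 through #8 differ only in which base classes they hand to class_registration.update_registration: each framework registers the replacement kinds it supports. A simplified registry sketch, explicitly not the Model Optimizer implementation, showing how sub-classes can be collected per base once their modules are imported:

def collect_registry(base_classes):
    registry = {}
    for base in base_classes:
        # __subclasses__() only sees classes whose defining modules have been
        # imported, which is why import_by_path() has to run beforehand
        registry[base.__name__] = base.__subclasses__()
    return registry

class Op: ...
class MyConv(Op): ...

print(collect_registry([Op]))  # {'Op': [<class '__main__.MyConv'>]}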