def collect_extenders(path: str):
    """
    Register every MO IR Reader extender.

    Imports all modules of the ``mo.utils.ir_reader.extenders`` package located under the given
    Model Optimizer folder, then runs registration for the ``Extender`` class hierarchy with
    empty enabled/disabled transform lists.

    :param path: path to the Model Optimizer folder
    :return: None
    """
    package_parts = ['mo', 'utils', 'ir_reader', 'extenders']
    extenders_dir = os.path.join(path, *package_parts)
    import_by_path(extenders_dir, package_parts)
    update_registration(classes=[Extender], enabled_transforms=[], disabled_transforms=[])
def collect_ops(path: str):
    """
    Register every MO operation class.

    Imports the ``mo.ops`` package and then the ``extensions.ops`` package from under the given
    Model Optimizer folder. It then runs registration for the operation base classes (``Op``,
    ``Activation``, ``Elementwise``, ``LogicalElementwise``, ``ReduceOp``, ``Scatter``) with empty
    enabled/disabled transform lists.

    :param path: path to the Model Optimizer folder
    :return: None
    """
    # Order matters: mo/ops is imported before extensions/ops, as in the original.
    for package in (('mo', 'ops'), ('extensions', 'ops')):
        import_by_path(os.path.join(path, *package), list(package))

    op_base_classes = [Op, Activation, Elementwise, LogicalElementwise, ReduceOp, Scatter]
    update_registration(classes=op_base_classes, enabled_transforms=[], disabled_transforms=[])
def load_dir(framework: str, path: str, get_front_classes: callable):
    """
    Import and register the transformation/extension modules found under ``path``.

    Assuming the following sub-directory structure for path:

        ops/
            <ops_files>.py
        analysis/
            <other_files>.py
        load/
            <framework>/
                <other_files>.py
        front/
            <other_files>.py
            <framework>/
                <other_files>.py
                extractors/            (only loaded when the last component of path is 'mo')
                    <other_files>.py
            <other_directories>/
                <other_files>.py
        middle/
            <other_files>.py
        back/
            <other_files>.py

    This function loads modules in the following order, running registration right after each
    directory is imported:
        1. ops/<ops_files>.py                                  -> Op
        2. analysis/<other_files>.py                           -> AnalyzeAction
        3. load/<framework>/<other_files>.py                   -> Loader
        4. front/<other_files>.py                              -> get_front_classes()
        5. front/<framework>/<other_files>.py                  -> get_front_classes()
        6. middle/<other_files>.py                             -> MiddleReplacementPattern
        7. back/<other_files>.py                               -> BackReplacementPattern
        8. front/<framework>/extractors/<other_files>.py       -> get_front_classes()  (path ends in 'mo')

    Handlers loaded later override earlier registered handlers for an op. The ops and front
    modules can concur for the same op. The middle modules register transformation passes and
    should not conflict with anything loaded by the ops or front modules.
    It doesn't load files from front/<other_directories>.

    The parent directory of ``path`` is temporarily prepended to ``sys.path`` so the modules can
    be imported as ``<last component of path>.<sub dirs>...``. It is always removed again, even
    when an import or registration fails.

    :param framework: framework name used as the sub-directory under ``load/`` and ``front/``
    :param path: root directory of the extensions to load
    :param get_front_classes: callable returning the list of classes to register for front modules
    """
    log.info("Importing extensions from: %s", path)
    root_dir, ext = os.path.split(path)
    sys.path.insert(0, root_dir)
    try:
        enabled_transforms, disabled_transforms = get_enabled_and_disabled_transforms()
        front_classes = get_front_classes()

        # Dict insertion order defines the import/registration order described above.
        internal_dirs = {
            ('ops',): [Op],
            ('analysis',): [AnalyzeAction],
            ('load', framework): [Loader],
            ('front',): front_classes,
            ('front', framework): front_classes,
            ('middle',): [MiddleReplacementPattern],
            ('back',): [BackReplacementPattern],
        }
        if ext == 'mo':
            internal_dirs[('front', framework, 'extractors')] = front_classes

        for sub_dir, classes in internal_dirs.items():
            import_by_path(os.path.join(path, *sub_dir), [ext, *sub_dir])
            update_registration(classes, enabled_transforms, disabled_transforms)
    finally:
        # Previously skipped on error, which left root_dir shadowing other modules on sys.path.
        sys.path.remove(root_dir)
def find_and_replace_pattern(self, graph: Graph):
    """
    Generate and register single-node replacers described in the custom operations config.

    Reads the file named by the ``tensorflow_use_custom_operations_config`` command-line
    parameter, taken from ``graph.graph['cmd_params']``. For every replacement description that
    has an ``op`` attribute, it creates a subclass of ``FrontReplacementFromConfigFileOp`` named
    ``'FrontReplacementFromConfigFileOp' + <op>`` with ``replacement_id`` set to the description
    id. It then re-runs registration for ``FrontReplacementFromConfigFileOp`` so the generated
    subclasses are picked up.

    :param graph: graph whose ``cmd_params`` carry the command-line arguments; only read here
    """
    argv = graph.graph['cmd_params']
    registry = CustomReplacementRegistry()
    registry.add_custom_replacement_description_from_config(argv.tensorflow_use_custom_operations_config)

    # Automatically generate sub-classes for custom replacements that replace a sub-graph with a
    # single node. Only the base class is handed to update_registration() below, so the generated
    # classes are presumably discovered via FrontReplacementFromConfigFileOp.__subclasses__().
    # That list holds weak references only, and a class object lives in a reference cycle, so an
    # otherwise unreferenced class can be reclaimed by the cyclic GC before registration runs.
    # Keep strong references until registration has happened.
    generated_classes = []
    for replacement_desc in registry.get_all_replacements_descriptions():
        if replacement_desc.has('op'):
            generated_classes.append(
                type('FrontReplacementFromConfigFileOp' + replacement_desc.op,
                     (FrontReplacementFromConfigFileOp,),
                     {'replacement_id': replacement_desc.id}))
    update_registration([FrontReplacementFromConfigFileOp], *get_enabled_and_disabled_transforms())
def driver(argv: argparse.Namespace, proto_file_name: str, model_file_name: str, output_model_name: str,
           outputs: list, output_dir: str, scale: float,
           user_shapes: [None, list, np.array] = None, mean_scale_values: [dict, list] = (), mean_file: str = "",
           mean_file_offsets: tuple = None, custom_layers_mapping_path: str = None):
    """
    Run the Caffe conversion pipeline and emit the IR.

    Steps:
        * load the prototxt/caffemodel pair and build the graph;
        * extract node attributes (common fields, built-in/extension extractors, custom layers);
        * apply user inputs/outputs/shapes/batch overrides;
        * run the front replacers, shape inference, the middle replacers and the fusing passes;
        * run the back replacers and hand the result to ``prepare_emit_ir``.

    :param argv: parsed command-line arguments. The optional flags ``disable_omitting_optional``,
        ``disable_flattening_optional_params`` and ``enable_flattening_nested_params`` default to
        False when absent.
    :param proto_file_name: path to the Caffe prototxt file (stored as ``graph.proto_path``)
    :param model_file_name: path to the Caffe model file (stored as ``graph.caffemodel_path``)
    :param output_model_name: IR name; also the graph name when the prototxt has no ``name``
    :param outputs: user-requested outputs, passed to ``user_data_repack``
    :param output_dir: directory passed to ``prepare_emit_ir``
    :param scale: value passed to ``scale_input``
    :param user_shapes: user-specified input shapes, passed to ``user_data_repack``
    :param mean_scale_values: values passed to ``add_mean_scale_values``
    :param mean_file: optional mean file; supported only for single-input topologies
    :param mean_file_offsets: offsets passed to ``loader.parse_mean``
    :param custom_layers_mapping_path: path passed to ``custom_layers_mapping.load_layers_xml``
    :return: 0 on success
    :raises Error: if the prototxt cannot be parsed, or the mean file cannot be processed or is
        given for a topology with several inputs
    """
    meta_info = get_meta_info(argv)

    FusePermutesSequence.enabled = False

    proto, model = loader.load_caffe_proto_model(proto_file_name, model_file_name)

    update_extractors_with_extensions(
        caffe_type_extractors,
        getattr(argv, 'disable_omitting_optional', False),
        getattr(argv, 'disable_flattening_optional_params', False))

    try:
        graph, original_shapes = loader.caffe_pb_to_nx(proto, model)
    except ValueError as e:
        raise Error('Invalid prototxt file: value error {}. ' + refer_to_faq_msg(11), str(e)) from e

    log.debug("After caffe_pb_to_nx")
    print_graph_stat(graph)
    check_empty_graph(graph, 'load_caffe_proto_model')

    graph.proto_path = proto_file_name
    graph.caffemodel_path = model_file_name
    graph.name = getattr(proto, 'name', None) or output_model_name
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'caffe'
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 4

    extract_node_attrs(graph, lambda node: (True, common_caffe_fields(node)))

    log.debug("After adding specific nodes for outputs")
    print_graph_stat(graph)

    custom_layers_map = custom_layers_mapping.load_layers_xml(custom_layers_mapping_path)
    custom_layers_mapping.update_extractors(
        caffe_type_extractors,
        custom_layers_map,
        getattr(argv, 'disable_omitting_optional', False),
        getattr(argv, 'enable_flattening_nested_params', False))

    # NOTE(review): check_for_duplicates() is re-evaluated for every node; it could likely be hoisted
    # if caffe_type_extractors is never modified during extraction -- confirm before changing.
    extract_node_attrs(graph, lambda node: caffe_extractor(node, check_for_duplicates(caffe_type_extractors)))

    log.debug("After extract_node_attr")
    print_graph_stat(graph)

    packed_user_shapes, packed_outputs, freeze_placeholder = user_data_repack(
        graph, user_shapes, outputs, argv.freeze_placeholder_with_value)
    if argv.freeze_placeholder_with_value is not None:
        FreezePlaceholderValue.enabled = True
        FreezePlaceholderValue.replacement_dict = freeze_placeholder
        # Re-register so the just-enabled FreezePlaceholderValue is taken into account.
        class_registration.update_registration([FrontReplacementSubgraph])
    add_output_ops(graph, packed_outputs)
    add_input_ops(graph, packed_user_shapes, True)
    override_placeholder_shapes(graph, packed_user_shapes)
    override_batch(graph, argv.batch)
    graph_clean_up(graph)
    check_empty_graph(graph, 'add_output_ops and add_input_ops')

    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)

    graph = create_tensor_nodes(graph)

    log.debug("After create_tensor_nodes")
    print_graph_stat(graph)

    remove_op_nodes(graph, {'op': 'Identity'})
    remove_output_ops(graph)
    graph_clean_up(graph)

    log.debug("After removing specific nodes for output")
    print_graph_stat(graph)

    # you need to pass required network outputs here
    # but we don't have a way yet, so just passing all discovered sinks
    mark_outputs(graph)
    graph_clean_up(graph)
    log.debug("After graph_cleanup")
    print_graph_stat(graph)

    graph = partial_infer(graph)
    log.debug("After partial_infer")
    print_graph_stat(graph)
    check_empty_graph(graph, 'partial_infer')
    duplicate_shared_weights(graph)

    add_input_ops(graph, packed_user_shapes, False)
    graph_clean_up(graph)
    check_empty_graph(graph, 'add_input_ops')
    scale_input(graph, scale)

    add_mean_scale_values(graph, mean_scale_values)

    log.debug("Split multi input convolutions")
    convert_multi_input_conv(graph)

    graph_clean_up(graph)
    log.debug("After graph_cleanup")
    print_graph_stat(graph)

    remove_op_nodes(graph, {'op': 'Dropout'})
    remove_op_nodes(graph, {'phase': 0})
    graph_clean_up(graph)

    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    mean_to_avgpool(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Needed even without fusing, to convert ScaleShift with 2 inputs.
    convert_scale_shift_to_mul_add(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        convert_bn_to_mul_add(graph)
        graph_clean_up(graph)

        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    convert_muladd_to_scaleshift_or_power(graph)
    convert_matmul_to_fully_connected(graph)
    batch_norm_fuse(graph)
    convert_mul_add_to_power(graph)
    convert_add_to_scaleshift(graph)  # scale = 1
    convert_mul_to_scaleshift(graph)  # biases = 0

    graph_clean_up(graph)
    log.debug("After graph_cleanup")
    print_graph_stat(graph)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    fuse_sequence_of_reshapes(graph)

    input_names = find_inputs(graph)
    mf = []
    try:
        if mean_file and len(original_shapes) == 1:
            mf = loader.parse_mean(mean_file, original_shapes[input_names[0]], mean_file_offsets)
        elif mean_file:
            raise Error('Mean file for topologies with multiple inputs is not supported. ' + refer_to_faq_msg(9))
    except ValueError as e:
        raise Error('Cannot load or process mean file: value error {}. ' + refer_to_faq_msg(10), str(e)) from e

    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir,
                    output_model_name=output_model_name, mean_data=mf, input_names=input_names,
                    meta_info=meta_info)
    return 0
def update_registration():
    """
    Re-run class registration for this pipeline's base classes.

    Passes the ``Op``, front extractor (including ``MXNetCustomFrontExtractorOp``), front, middle
    and back replacement base classes to ``class_registration.update_registration``.
    """
    registered_bases = [
        Op,
        FrontExtractorOp,
        FrontReplacementOp,
        FrontReplacementSubgraph,
        MXNetCustomFrontExtractorOp,
        MiddleReplacementPattern,
        BackReplacementPattern,
        FrontReplacementPattern,
    ]
    class_registration.update_registration(registered_bases)
def update_registration():
    """
    Re-run class registration for this pipeline's base classes.

    Passes the ``Op``, front extractor (including ``CaffePythonFrontExtractorOp``), front, middle
    and back replacement base classes to ``class_registration.update_registration``.
    """
    registered_bases = [Op, FrontExtractorOp, CaffePythonFrontExtractorOp]
    registered_bases += [FrontReplacementOp, FrontReplacementPattern, FrontReplacementSubgraph]
    registered_bases += [MiddleReplacementPattern, BackReplacementPattern]
    class_registration.update_registration(registered_bases)
def update_registration():
    """
    Re-run class registration for this pipeline's base classes.

    Besides the ``Op``, front extractor and front/middle/back replacement base classes, this
    variant also registers the three ``FrontReplacementFromConfigFile*`` replacement base classes.
    """
    registered_bases = [
        Op,
        FrontExtractorOp,
        FrontReplacementOp,
        FrontReplacementPattern,
        FrontReplacementSubgraph,
        FrontReplacementFromConfigFileSubGraph,
        FrontReplacementFromConfigFileOp,
        MiddleReplacementPattern,
        BackReplacementPattern,
        FrontReplacementFromConfigFileGeneral,
    ]
    class_registration.update_registration(registered_bases)