def test_force_precision_parameter(self):
    """A 'force_precision' attribute on the data node must be propagated by CreateConstNodesReplacement.

    The reference graph expects the attribute on the new 'const' op and 'const_data' node,
    and it is additionally checked directly on the generated graph nodes.
    """
    forced = 'FP16'
    dims = np.array([2, 3, 4])
    blob = np.zeros(dims)

    graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes,
        edges_with_attrs=self.edges,
        update_nodes_attributes=[
            ('data_node', {'shape': dims, 'value': blob, 'force_precision': forced}),
        ],
    )
    expected_graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + self.new_nodes,
        edges_with_attrs=self.edges + self.new_edges,
        update_nodes_attributes=[
            ('data_node', {'shape': dims, 'value': blob}),
            ('const_data', {'shape': dims, 'value': blob, 'force_precision': forced}),
            ('const', {'force_precision': forced}),
        ],
    )

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, expected_graph, last_node='next_node')
    self.assertTrue(flag, resp)

    # Read both generated nodes before asserting, so a missing node surfaces as a lookup error.
    const_op_precision = graph.nodes['data_node_const']['force_precision']
    copied_data_precision = graph.nodes['data_node_copy_']['force_precision']
    self.assertEqual(const_op_precision, forced)
    self.assertEqual(copied_data_precision, forced)
def test_one_bin_node(self):
    """Nothing should happen when the only consumer edge already has a 'bin' attribute.

    The transformed graph is compared with an independently built, untransformed
    copy of the same input. Comparing the graph with itself would always succeed
    and could never detect an unwanted change.
    """
    shape = np.array([2, 3, 4])
    data = np.zeros(shape)

    def build_input_graph():
        return build_graph_with_attrs(
            nodes_with_attrs=self.nodes,
            edges_with_attrs=self.edges,
            update_nodes_attributes=[('data_node', {'shape': shape, 'value': data})],
            update_edge_attrs={('data_node', 'next_node', 0): {'bin': 0}},
        )

    graph = build_input_graph()
    graph_ref = build_input_graph()

    tested_pattern = CreateConstNodesReplacement()
    tested_pattern.find_and_replace_pattern(graph)

    (flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node')
    self.assertTrue(flag, resp)
def test_two_nodes_with_bin(self):
    """Data node with 2 consumers, both edges carrying a 'bin' attribute: nothing should happen.

    The transformed graph is compared with an independently built, untransformed
    copy of the same input. Comparing the graph with itself would always succeed.
    """
    shape = np.array([2, 3, 4])
    data = np.zeros(shape)

    def build_input_graph():
        return build_graph_with_attrs(
            nodes_with_attrs=self.nodes + [('next_node_2', {'kind': 'op'})],
            edges_with_attrs=self.edges + [('data_node', 'next_node_2')],
            update_nodes_attributes=[('data_node', {'shape': shape, 'value': data})],
            update_edge_attrs={('data_node', 'next_node', 0): {'bin': 0},
                               ('data_node', 'next_node_2', 0): {'bin': 0}},
        )

    graph = build_input_graph()
    graph_ref = build_input_graph()

    tested_pattern = CreateConstNodesReplacement()
    tested_pattern.find_and_replace_pattern(graph)

    (flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node')
    self.assertTrue(flag, resp)
def test_one_node(self):
    """A data node with a single non-'bin' consumer must get a Const op and a data node added."""
    dims = np.array([2, 3, 4])
    blob = np.zeros(dims)

    graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes,
        edges_with_attrs=self.edges,
        update_nodes_attributes=[('data_node', {'shape': dims, 'value': blob})],
    )
    expected_graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + self.new_nodes,
        edges_with_attrs=self.edges + self.new_edges,
        update_nodes_attributes=[
            ('data_node', {'shape': dims, 'value': blob}),
            ('const_data', {'shape': dims, 'value': blob}),
        ],
    )

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, expected_graph, last_node='next_node')
    self.assertTrue(flag, resp)
def test_two_nodes_one_bin(self):
    """Two consumers: one edge has a 'bin' attribute, the other does not, so Const nodes are expected."""
    dims = np.array([2, 3, 4])
    blob = np.zeros(dims)
    second_consumer = [('next_node_2', {'kind': 'op'})]
    second_edge = [('data_node', 'next_node_2')]

    graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + second_consumer,
        edges_with_attrs=self.edges + second_edge,
        update_nodes_attributes=[('data_node', {'shape': dims, 'value': blob})],
        update_edge_attrs={('data_node', 'next_node', 0): {'bin': 0}},
    )
    expected_graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + self.new_nodes + [('next_node_2', {'kind': 'op'})],
        edges_with_attrs=self.edges + self.new_edges + [('data_node', 'next_node_2')],
        update_nodes_attributes=[
            ('data_node', {'shape': dims, 'value': blob}),
            ('const_data', {'shape': dims, 'value': blob}),
        ],
    )

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, expected_graph, last_node='next_node')
    self.assertTrue(flag, resp)
def driver(argv, input_model, output_model_name, output_dir):
    """Convert a Kaldi model to IR and write it to `output_dir`.

    Pipeline: load + attribute extraction, FRONT replacers, partial inference,
    reshape clean-up, optional counts/softmax handling, then IR emission.

    :param argv: parsed command-line parameters. Stored as graph.graph['cmd_params'];
        this function reads generate_deprecated_IR_V2, counts, remove_output_softmax
        and data_type from it.
    :param input_model: model passed to load_kaldi_model (presumably a file path — confirm).
    :param output_model_name: name of the emitted IR.
    :param output_dir: directory passed to prepare_emit_ir.
    :return: 0 after the IR has been emitted.
    :raises Error: when the model or the counts file cannot be read.
    """
    meta_info = get_meta_info(argv)

    # The Eltwise checker is switched off for the whole Kaldi conversion.
    EltwiseChecker.enabled = False

    try:
        graph, input_shapes = load_kaldi_model(input_model)
    except Exception as e:
        raise Error('Model Optimizer is not able to read Kaldi model {}. '.format(input_model) +
                    refer_to_faq_msg(91)) from e
    graph.check_empty_graph('load_kaldi_nnet_model')
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'kaldi'
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
    update_extractors_with_extensions(kaldi_type_extractors)

    extract_node_attrs(graph, lambda node: kaldi_extractor(node))

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)

    graph = partial_infer(graph)

    # The order is intentional, firstly eliminate repeated, then remove redundant
    FuseRepeatedReshapes().find_and_replace_pattern(graph)
    EliminateRedundantReshape().find_and_replace_pattern(graph)
    graph.check_empty_graph('partial_infer')
    if argv.counts:
        try:
            counts = read_counts_file(argv.counts)
        except Exception as e:
            raise Error('Model Optimizer is not able to read counts file {}'.format(argv.counts) +
                        refer_to_faq_msg(92)) from e

        apply_biases_to_last_layer(graph, counts)

    if argv.remove_output_softmax:
        RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
        graph_clean_up(graph)
        log.debug("After removing softmax")
        graph.print_graph_stat()

    # Intentionally after all transformations
    KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)

    remove_const_ops(graph)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    remove_output_ops(graph)

    prepare_emit_ir(graph, argv.data_type, output_dir, output_model_name, meta_info=meta_info)
    return 0
def driver_R1(onnx_modelproto_bytes, precision: str, output_model_name: str, outputs: list, output_dir: str,
              scale: float,
              user_shapes: [None, list, np.array] = None,
              mean_scale_values: [dict, list] = ()):
    """Convert a serialized ONNX ModelProto to IR version 5 and return it in memory.

    Command-line parameters are not available here, so a fixed argparse.Namespace
    with default values is stored as graph.graph['cmd_params'].

    :param onnx_modelproto_bytes: serialized ONNX ModelProto (anything bytes() accepts).
    :param precision: passed to prepare_emit_ir as data_type.
    :param output_model_name: graph name; falls back to the ONNX graph name when empty.
    :param outputs: not used by this function body (kept for interface compatibility).
    :param output_dir: passed to prepare_emit_ir.
    :param scale: not used by this function body.
    :param user_shapes: not used by this function body.
    :param mean_scale_values: not used by this function body.
    :return: the (weights, xml_string) pair produced by prepare_emit_ir.
    :raises Error: when the bytes cannot be parsed as an ONNX model or the graph
        cannot be pre-processed.
    """
    try:
        model_proto = onnx.load_from_string(bytes(onnx_modelproto_bytes))
    except Exception as e:
        # Previously the failure was only printed and execution continued with an
        # unbound `model_proto`; fail explicitly with the underlying reason instead.
        raise Error('Cannot load ONNX model from the provided bytes. Details: {}. ' +
                    refer_to_faq_msg(44), str(e)) from e

    update_extractors_with_extensions(onnx_op_extractors)

    try:
        graph = protobuf2nx(model_proto)
        log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))

        graph.__setattr__('name',
                          output_model_name if output_model_name else model_proto.graph.name)  # pylint: disable=no-member
        graph.graph['layout'] = 'NCHW'
        graph.graph['cmd_params'] = argparse.Namespace(
            batch=None,
            data_type='float',
            disable_fusing=False,
            disable_gfusing=False,
            disable_resnet_optimization=False,
            enable_concat_optimization=False,
            extensions=mo_extensions,
            finegrain_fusing=None,
            framework='onnx',
            freeze_placeholder_with_value=None,
            generate_deprecated_IR_V2=False,
            input=None,
            input_model=None,
            input_shape=None,
            keep_shape_ops=False,
            log_level='ERROR',
            mean_scale_values={},
            mean_values=(),
            model_name=None,
            move_to_preprocess=False,
            output=None,
            output_dir='.',
            placeholder_shapes=None,
            reverse_input_channels=False,
            scale=None,
            scale_values=(),
            silent=False,
            version=False)
        graph.graph['fw'] = 'onnx'
        graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
        graph.graph['ir_version'] = 5
    except Exception as e:
        # There is no model file here (input is bytes), so the message must not
        # reference a file name.
        raise Error(
            'Cannot pre-process ONNX graph after reading model from bytes. '
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            str(e)
        ) from e
    graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)
    graph_clean_up_onnx(graph)

    # No fine-grain fusing masks are configured (the Namespace above sets
    # finegrain_fusing=None); the string 'False' would be truthy and treated as a mask.
    mark_unfused_nodes(graph, False)

    convert_batch_norm(graph)
    graph_clean_up_onnx(graph)

    convert_scale_shift_to_mul_add(graph)
    graph_clean_up_onnx(graph)

    fuse_mul_add_sequence(graph)
    graph_clean_up_onnx(graph)

    fuse_linear_ops(graph)
    graph_clean_up_onnx(graph)

    grouped_convolutions_fusing(graph)
    graph_clean_up_onnx(graph)

    fuse_linear_ops(graph)
    graph_clean_up_onnx(graph)

    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up_onnx(graph)

    convert_mul_add_to_power(graph)
    graph_clean_up_onnx(graph)

    convert_reshape(graph)
    graph_clean_up_onnx(graph)
    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up_onnx(graph)

    fuse_pad(graph)
    graph_clean_up_onnx(graph)

    fuse_sequence_of_reshapes(graph)
    graph_clean_up_onnx(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    merge_nodes_permutations(graph)
    permute_data_nodes_attrs(graph)
    permute_op_nodes_attrs(graph)

    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    weights, xml_string = prepare_emit_ir(graph=graph, data_type=precision, output_dir=output_dir,
                                          output_model_name=output_model_name, meta_info={'unset': []})

    return weights, xml_string
def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: str, output_dir: str):
    """Convert an ONNX model file to IR and write it to `output_dir`.

    :param argv: parsed command-line parameters, stored as graph.graph['cmd_params'].
        This function reads generate_deprecated_IR_V2, finegrain_fusing,
        disable_fusing, disable_gfusing, reverse_input_channels, move_to_preprocess
        and data_type from it.
    :param model_file_name: ONNX model passed to load_onnx_model.
    :param output_model_name: graph/IR name; falls back to the ONNX graph name when empty.
    :param output_dir: directory passed to prepare_emit_ir.
    :return: 0 after the IR has been emitted.
    :raises Error: when the ONNX graph cannot be pre-processed.
    """
    meta_info = get_meta_info(argv)

    model_proto = load_onnx_model(model_file_name)
    model_graph = model_proto.graph  # pylint: disable=no-member
    log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
    log.debug("Number of all input ports (not true inputs) in graph_def: {}".format(len(model_graph.input)))
    log.debug("Number of initializers in graph_def: {}".format(len(model_graph.initializer)))
    # Initializers are listed among the graph inputs, hence the subtraction.
    log.debug("Number of real inputs in graph_def: {}".format(len(model_graph.input) - len(model_graph.initializer)))
    update_extractors_with_extensions(onnx_op_extractors)

    try:
        graph = protobuf2nx(model_proto)
        log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))

        graph.__setattr__('name',
                          output_model_name if output_model_name else model_proto.graph.name)  # pylint: disable=no-member
        graph.graph['layout'] = 'NCHW'
        graph.graph['cmd_params'] = argv
        graph.graph['fw'] = 'onnx'
        graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
        graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
    except Exception as e:
        raise Error(
            'Cannot pre-process ONNX graph after reading from model file "{}". ' \
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            model_file_name,
            str(e)
        ) from e
    graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)
    graph_clean_up_onnx(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    # IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
    convert_batch_norm(graph)
    graph_clean_up_onnx(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up_onnx(graph)

        # Fusing the sequences of Mul/Add operations
        fuse_mul_add_sequence(graph)
        graph_clean_up_onnx(graph)

        # Fusing linear operation to Convolution
        fuse_linear_ops(graph)
        graph_clean_up_onnx(graph)

    if not argv.disable_gfusing:
        grouped_convolutions_fusing(graph)
        graph_clean_up_onnx(graph)
        if not argv.disable_fusing:
            fuse_linear_ops(graph)
            graph_clean_up_onnx(graph)

    AddQuantizeFuse().find_and_replace_pattern(graph)
    MulQuantizeFuse().find_and_replace_pattern(graph)

    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up_onnx(graph)

    convert_mul_add_to_power(graph)
    graph_clean_up_onnx(graph)

    convert_reshape(graph)
    graph_clean_up_onnx(graph)
    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up_onnx(graph)

    fuse_pad(graph)
    graph_clean_up_onnx(graph)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up_onnx(graph)

    fuse_sequence_of_reshapes(graph)
    graph_clean_up_onnx(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    merge_nodes_permutations(graph)
    permute_data_nodes_attrs(graph)
    permute_op_nodes_attrs(graph)

    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir,
                    output_model_name=output_model_name, meta_info=meta_info)

    return 0
def driver(onnx_modelproto_bytes, precision: str, output_model_name: str, output_dir: str):
    """Convert a serialized ONNX ModelProto to IR version 6 and return it in memory.

    Command-line parameters are not available here, so a fixed argparse.Namespace
    with default values is stored as graph.graph['cmd_params'].

    :param onnx_modelproto_bytes: serialized ONNX ModelProto (anything bytes() accepts).
    :param precision: passed to prepare_emit_ir as data_type.
    :param output_model_name: graph name; falls back to the ONNX graph name when empty.
    :param output_dir: passed to prepare_emit_ir.
    :return: the (weights, xml_string) pair produced by prepare_emit_ir.
    :raises Error: when the bytes cannot be parsed as an ONNX model or the graph
        cannot be pre-processed.
    """
    try:
        model_proto = onnx.load_from_string(bytes(onnx_modelproto_bytes))
    except Exception as e:
        # Previously the failure was only printed and execution continued with an
        # unbound `model_proto`; fail explicitly with the underlying reason instead.
        raise Error('Cannot load ONNX model from the provided bytes. Details: {}. ' +
                    refer_to_faq_msg(44), str(e)) from e

    model_graph = model_proto.graph  # pylint: disable=no-member
    log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
    log.debug("Number of all input ports (not true inputs) in graph_def: {}".format(len(model_graph.input)))
    log.debug("Number of initializers in graph_def: {}".format(len(model_graph.initializer)))
    # Initializers are listed among the graph inputs, hence the subtraction.
    log.debug("Number of real inputs in graph_def: {}".format(len(model_graph.input) - len(model_graph.initializer)))
    update_extractors_with_extensions(onnx_op_extractors)

    try:
        graph = protobuf2nx(model_proto)
        log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))

        graph.__setattr__('name',
                          output_model_name if output_model_name else model_proto.graph.name)  # pylint: disable=no-member
        graph.graph['layout'] = 'NCHW'
        graph.graph['fw'] = 'onnx'
        graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
        graph.graph['cmd_params'] = argparse.Namespace(
            batch=None,
            data_type='float',
            disable_fusing=False,
            disable_gfusing=False,
            disable_resnet_optimization=False,
            enable_concat_optimization=False,
            extensions=mo_extensions,
            finegrain_fusing=None,
            framework='onnx',
            freeze_placeholder_with_value=None,
            generate_deprecated_IR_V2=False,
            input=None,
            input_model=None,
            input_shape=None,
            keep_shape_ops=False,
            log_level='ERROR',
            mean_scale_values={},
            mean_values=(),
            model_name=None,
            move_to_preprocess=False,
            output=None,
            output_dir='.',
            placeholder_shapes=None,
            reverse_input_channels=False,
            scale=None,
            scale_values=(),
            silent=False,
            version=False,
            blobs_as_inputs=False,
            keep_quantize_ops_in_IR=False,
            generate_experimental_IR_V10=False)
        graph.graph['ir_version'] = 6
    except Exception as e:
        # There is no model file here (input is bytes); the old message formatted the
        # undefined name `model_file_name`, which raised NameError and hid `e`.
        raise Error(
            'Cannot pre-process ONNX graph after reading model from bytes. '
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            str(e)
        ) from e
    graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)
    graph_clean_up_onnx(graph)

    for_graph_and_each_sub_graph_recursively(graph, convert_matmul_to_fully_connected)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, False)

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    # IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
    convert_batch_norm(graph)
    graph_clean_up_onnx(graph)

    # Converting ScaleShift layer to Mul->Add
    convert_scale_shift_to_mul_add(graph)
    graph_clean_up_onnx(graph)

    # Fusing the sequences of Mul/Add operations
    fuse_mul_add_sequence(graph)
    graph_clean_up_onnx(graph)

    # Fusing linear operation to Convolution
    fuse_linear_ops(graph)
    graph_clean_up_onnx(graph)

    grouped_convolutions_fusing(graph)
    graph_clean_up_onnx(graph)

    fuse_linear_ops(graph)
    graph_clean_up_onnx(graph)

    MarkNodesToFuseUpToFakeQuantize().find_and_replace_pattern(graph)
    FakeQuantizeFuse().find_and_replace_pattern(graph)

    AddFakeQuantizeFuse().find_and_replace_pattern(graph)
    MulFakeQuantizeFuse().find_and_replace_pattern(graph)

    convert_muladd_to_scaleshift(graph)
    graph_clean_up_onnx(graph)

    # NOTE(review): second consecutive clean-up kept as in the original pass order.
    graph_clean_up_onnx(graph)
    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up_onnx(graph)

    fuse_pad(graph)
    graph_clean_up_onnx(graph)

    FuseReshapesSequence().find_and_replace_pattern(graph)
    RemoveRedundantReshapes().find_and_replace_pattern(graph)

    graph_clean_up_onnx(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    merge_nodes_permutations(graph)
    permute_data_nodes_attrs(graph)
    permute_op_nodes_attrs(graph)

    graph_clean_up_onnx(graph)
    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    weights, xml_string = prepare_emit_ir(graph=graph, data_type=precision, output_dir=output_dir,
                                          output_model_name=output_model_name, meta_info={'unset': []})

    return weights, xml_string
def driver(argv, input_model, output_model_name, output_dir):
    """Convert a Kaldi model to IR, logging the LOAD/FRONT/MIDDLE/BACK/EMIT steps.

    IR version: 2 when argv.generate_deprecated_IR_V2 is set, otherwise 10 when
    argv.generate_experimental_IR_V10 is set, otherwise 6.

    :param argv: parsed command-line parameters, stored as graph.graph['cmd_params'].
        This function reads steps, generate_deprecated_IR_V2,
        generate_experimental_IR_V10, counts, remove_output_softmax, remove_memory
        and data_type from it.
    :param input_model: model passed to load_kaldi_model (presumably a file path — confirm).
    :param output_model_name: name of the emitted IR.
    :param output_dir: directory passed to prepare_emit_ir.
    :return: 0 after the IR has been emitted.
    :raises Error: when the model or the counts file cannot be read.
    """
    log_step(argv.steps, 'LOAD')
    meta_info = get_meta_info(argv)

    # The Eltwise checker is switched off for the whole Kaldi conversion.
    EltwiseChecker.enabled = False

    try:
        graph = load_kaldi_model(input_model)
    except Exception as e:
        raise Error('Model Optimizer is not able to parse Kaldi model {}. '.format(input_model) +
                    refer_to_faq_msg(91)) from e
    graph.check_empty_graph('load_kaldi_nnet_model')
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'kaldi'

    if graph.graph['cmd_params'].generate_experimental_IR_V10:
        version = 10
    else:
        version = 6
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version

    update_extractors_with_extensions(kaldi_type_extractors)
    extract_node_attrs(graph, lambda node: kaldi_extractor(node))

    # --------------------------------- LOAD END ------------------------------------------------------
    log_step(argv.steps, 'FRONT')
    ReplaceLSTMNodePattern().find_and_replace_pattern(graph)
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    log_step(argv.steps, 'MIDDLE')
    graph = partial_infer(graph)

    ReplacePNormNodePattern().find_and_replace_pattern(graph)
    ReplaceMemoryOffsetNodePattern().find_and_replace_pattern(graph)
    ReplaceMemoryOffsetWithMemoryNodePattern().find_and_replace_pattern(graph)
    RemoveMemoryDuplicationPattern().find_and_replace_pattern(graph)
    MergeNeighborSplicePattern().find_and_replace_pattern(graph)
    RemoveUselessCropsPattern().find_and_replace_pattern(graph)
    RemoveIdentity().find_and_replace_pattern(graph)
    graph_clean_up(graph)

    AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)

    ReplaceSpliceNodePattern().find_and_replace_pattern(graph)
    graph_clean_up(graph)

    # The order is intentional, firstly eliminate repeated, then remove redundant
    FuseRepeatedReshapes().find_and_replace_pattern(graph)
    EliminateRedundantReshape().find_and_replace_pattern(graph)
    graph_clean_up(graph)
    graph.check_empty_graph('partial_infer')
    if argv.counts:
        try:
            counts = read_counts_file(argv.counts)
        except Exception as e:
            raise Error('Model Optimizer is not able to read counts file {}'.format(argv.counts) +
                        refer_to_faq_msg(92)) from e

        apply_biases_to_last_layer(graph, counts)

    if argv.remove_output_softmax:
        RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
        graph_clean_up(graph)
        log.debug("After removing softmax")
        graph.print_graph_stat()

    log_step(argv.steps, 'BACK')
    LeakyReluToReluWithNegativeSlope().find_and_replace_pattern(graph)
    TransposeToPermute().find_and_replace_pattern(graph)
    DivideToEltwises().find_and_replace_pattern(graph)
    SubtractToEltwises().find_and_replace_pattern(graph)
    SimpleEltwiseToEltwiseOp().find_and_replace_pattern(graph)
    for_graph_and_each_sub_graph_recursively(graph, convert_matmul_to_fully_connected)

    # Intentionally after all transformations
    if argv.remove_memory:
        CutMemory().find_and_replace_pattern(graph)
        graph_clean_up(graph)
        ParameterToInput().find_and_replace_pattern(graph)

    KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
    ForceStrictPrecision().find_and_replace_pattern(graph)
    remove_const_ops(graph)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    remove_output_ops(graph)
    log_step(argv.steps, 'EMIT')
    prepare_emit_ir(graph, argv.data_type, output_dir, output_model_name, meta_info=meta_info)
    return 0
def driver(argv: argparse.Namespace, proto_file_name: str, model_file_name: str, output_model_name: str,
           output_dir: str, caffe_proto_path: str, mean_file: str = "",
           mean_file_offsets: tuple = None, custom_layers_mapping_path: str = None):
    """Convert a Caffe model (prototxt + caffemodel) to IR, logging the conversion steps.

    IR version: 2 when argv.generate_deprecated_IR_V2 is set, otherwise 10 when
    argv.generate_experimental_IR_V10 is set, otherwise 6.

    :param argv: parsed command-line parameters, stored as graph.graph['cmd_params'].
    :param proto_file_name: prototxt passed to loader.load_caffe_proto_model.
    :param model_file_name: caffemodel passed to loader.load_caffe_proto_model.
    :param output_model_name: graph name used when the prototxt has no 'name'.
    :param output_dir: directory passed to prepare_emit_ir.
    :param caffe_proto_path: passed to loader.import_caffe_pb2.
    :param mean_file: optional mean file; only supported when the model has one input.
    :param mean_file_offsets: passed through to loader.parse_mean.
    :param custom_layers_mapping_path: passed to custom_layers_mapping.load_layers_xml.
    :return: 0 after the IR has been emitted.
    :raises Error: on an invalid prototxt (ValueError from caffe_pb_to_nx), on a
        mean file for a multi-input model, or when the mean file cannot be processed.
    """
    log_step(argv.steps, 'LOAD')
    meta_info = get_meta_info(argv)

    caffe_pb2 = loader.import_caffe_pb2(caffe_proto_path)

    proto, model = loader.load_caffe_proto_model(caffe_pb2, proto_file_name, model_file_name)
    # Optional argv attributes default to False when absent from the namespace.
    update_extractors_with_extensions(
        caffe_type_extractors,
        argv.disable_omitting_optional if hasattr(argv, 'disable_omitting_optional') else False,
        argv.disable_flattening_optional_params if hasattr(argv, 'disable_flattening_optional_params') else False
    )

    try:
        graph, original_shapes = loader.caffe_pb_to_nx(proto, model)
    except ValueError as e:
        raise Error('Invalid prototxt file: value error {}. ' +
                    refer_to_faq_msg(11), str(e)) from e

    log.debug("After caffe_pb_to_nx")
    graph.print_graph_stat()
    graph.check_empty_graph('load_caffe_proto_model')

    graph.__setattr__('proto_path', proto_file_name)
    graph.__setattr__('caffemodel_path', model_file_name)
    graph.__setattr__('name', getattr(proto, 'name', None) or output_model_name)
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'caffe'
    if graph.graph['cmd_params'].generate_experimental_IR_V10:
        version = 10
    else:
        version = 6
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version

    custom_layers_map = custom_layers_mapping.load_layers_xml(custom_layers_mapping_path)
    custom_layers_mapping.update_extractors(
        caffe_type_extractors,
        custom_layers_map,
        argv.disable_omitting_optional if hasattr(argv, 'disable_omitting_optional') else False,
        argv.enable_flattening_nested_params if hasattr(argv, 'enable_flattening_nested_params') else False
    )
    extract_node_attrs(graph, lambda node: caffe_extractor(node, check_for_duplicates(caffe_type_extractors)))

    # --------------------------------- LOAD END ------------------------------------------------------
    log_step(argv.steps, 'FRONT')
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    log_step(argv.steps, 'MIDDLE')
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # need this pass even without fusing to convert scale with 2 inputs
    convert_scale_shift_to_mul_add(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        convert_bn_to_mul_add(graph)
        graph_clean_up(graph)

        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    convert_muladd_to_scaleshift(graph)
    convert_matmul_to_fully_connected(graph)
    batch_norm_fuse(graph)
    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up(graph)

    log.debug("After graph_cleanup")
    graph.print_graph_stat()

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    FuseReshapesSequence().find_and_replace_pattern(graph)
    RemoveRedundantReshapes().find_and_replace_pattern(graph)

    input_names = find_inputs(graph)
    # Parsed mean data handed to prepare_emit_ir; stays empty without a mean file.
    mf = []
    try:
        if mean_file and len(original_shapes) == 1:
            mf = loader.parse_mean(mean_file, original_shapes[input_names[0]],
                                   mean_file_offsets, caffe_pb2)
        elif mean_file:
            raise Error('Mean file for topologies with multiple inputs is not supported. ' +
                        refer_to_faq_msg(9))
    except ValueError as e:
        raise Error('Cannot load or process mean file: value error {}. ' +
                    refer_to_faq_msg(10), str(e)) from e

    merge_nodes_permutations(graph)
    permute_data_nodes_attrs(graph)
    permute_op_nodes_attrs(graph)

    graph_clean_up(graph)
    log_step(argv.steps, 'BACK')
    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    remove_const_ops(graph)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    remove_output_ops(graph)

    log_step(argv.steps, 'EMIT')
    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    mean_data=mf,
                    input_names=input_names,
                    meta_info=meta_info)
    return 0
def tf2nx(argv: argparse.Namespace, model_file_name: str, output_model_name: str, output_dir: str,
          is_binary: bool):
    """
    Convert TF GraphDef object to NetworkX representation.
    The resulting graph is still TF-specific and needs normalization passes to be applied.
    The specific TF structure assumes each GraphDef node is converted to a single
    NetworkX node, node id is an original TF node name, and edges go directly from one op   to another op.

    After loading, the full conversion pipeline runs and the IR is emitted via
    prepare_emit_ir into `output_dir`.

    :param argv: parsed command-line parameters, stored as graph.graph['cmd_params'].
    :param model_file_name: graph file passed to load_tf_graph_def.
    :param output_model_name: graph/IR name.
    :param output_dir: directory passed to prepare_emit_ir.
    :param is_binary: passed to load_tf_graph_def.
    :return: 0 after the IR has been emitted.
    :raises Error: when the loaded graph cannot be pre-processed.
    """
    meta_info = get_meta_info(argv)

    if argv.tensorflow_custom_layer_libraries:
        libraries = argv.tensorflow_custom_layer_libraries.split(',')
        for library in libraries:
            log.info('Loading library "{}" with custom operations'.format(library))
            tf.load_op_library(library)

    graph_def, variables_values = load_tf_graph_def(graph_file_name=model_file_name, is_binary=is_binary,
                                                    checkpoint=argv.input_checkpoint,
                                                    user_output_node_names_list=argv.output,
                                                    model_dir=argv.saved_model_dir,
                                                    meta_graph_file=argv.input_meta_graph,
                                                    saved_model_tags=argv.saved_model_tags)

    try:
        tf.import_graph_def(graph_def, name='')
    except Exception as e:
        # Deliberately best-effort: conversion continues. Catch Exception (not a bare
        # except) so KeyboardInterrupt/SystemExit still propagate.
        log.warning("TensorFlow post-processing of loaded model was unsuccessful. "
                    "This is an optional step that Model Optimizer performs for any input model but it is not usually "
                    "required for all models. "
                    "It likely means that the original model is ill-formed. "
                    "Model Optimizer will continue converting this model.")
        log.debug("TensorFlow post-processing error details: {}".format(e))

    log.debug("Number of nodes in graph_def: {}".format(len(graph_def.node)))  # pylint: disable=no-member

    if argv.tensorboard_logdir:
        tensorboard.dump_for_tensorboard(graph_def, argv.tensorboard_logdir)

    update_extractors_with_extensions(tf_op_extractors)

    try:
        graph = protobuf2nx(graph_def)
        graph.__setattr__('name', output_model_name)
        # 'layout' parameter change may cause an issue in EltwiseInputReshape replacer
        # and convert_nhwc_to_nchw(graph)
        graph.graph['layout'] = 'NCHW' if argv.disable_nhwc_to_nchw else 'NHWC'
        graph.graph['cmd_params'] = argv
        graph.graph['fw'] = 'tf'
        graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5

        graph.graph['variables_values'] = variables_values
        del variables_values

        graph = restore_edges(graph, get_tf_edges)
        graph = remove_control_dependency_inputs(graph)
    except Exception as e:
        raise Error(
            'Cannot pre-process TensorFlow graph after reading from model file "{}". '
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            model_file_name,
            str(e)
        ) from e

    graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)
    graph_clean_up_tf(graph)

    convert_matmul_to_fully_connected(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing))

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    # IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
    convert_batch_norm(graph)
    graph_clean_up_tf(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        for_graph_and_each_sub_graph_recursively(graph, convert_scale_shift_to_mul_add)
        for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

        # Fusing the sequences of Mul/Add operations
        for_graph_and_each_sub_graph_recursively(graph, fuse_mul_add_sequence)
        for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

        # Fusing linear operation to Convolution
        for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops)
        for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    if not argv.disable_gfusing:
        grouped_convolutions_fusing(graph)
        graph_clean_up_tf(graph)
        if not argv.disable_fusing:
            fuse_linear_ops(graph)
            graph_clean_up_tf(graph)

    # Converting Mul->Add to ScaleShift node
    for_graph_and_each_sub_graph_recursively(graph, convert_muladd_to_scaleshift_or_power)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    for_graph_and_each_sub_graph_recursively(graph, convert_mul_add_to_power)

    # Need to eliminate dead nodes before doing update_fully_connected_shapes
    # because update_fully_connected_shapes does partial inference and dead
    # nodes will lead to sporadic failures.
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
    for_graph_and_each_sub_graph_recursively(graph, update_fully_connected_shapes)

    for_graph_and_each_sub_graph_recursively(graph, convert_mul_eltwise_to_leaky_relu)
    graph_clean_up_tf(graph)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    for_graph_and_each_sub_graph_recursively(graph, fuse_pad)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    for_graph_and_each_sub_graph_recursively(graph, convert_reshape)
    for_graph_and_each_sub_graph_recursively(graph, convert_squeeze)

    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    for_graph_and_each_sub_graph_recursively(graph, convert_add_or_mul_to_scaleshift)  # scale = 1
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up_tf(graph)

    fuse_sequence_of_reshapes(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    conv_flatten_concat(graph)

    if argv.enable_concat_optimization:
        ConcatOptimization().find_and_replace_pattern(graph)

    LayoutChangeForConstantShapePaths().find_and_replace_pattern(graph)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    for_graph_and_each_sub_graph_recursively(graph, apply_nhwc_to_nchw_permutation)
    for_graph_and_each_sub_graph_recursively(graph, merge_nodes_permutations)
    for_graph_and_each_sub_graph_recursively(graph, permute_data_nodes_attrs)
    for_graph_and_each_sub_graph_recursively(graph, permute_op_nodes_attrs)

    for_graph_and_each_sub_graph_recursively(graph, repack_fully_connected_weights_nhwc_to_nchw)
    for_graph_and_each_sub_graph_recursively(graph, transpose_fully_connected_weights)

    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    meta_info=meta_info)

    return 0
def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, output_dir: str):
    """Convert an MXNet model to IR and write it to `output_dir`.

    :param argv: parsed command-line parameters, stored as graph.graph['cmd_params'].
        Among others this function reads input_symbol, input, nd_prefix_name,
        pretrained_model_name, legacy_mxnet_model, save_params_from_nd,
        generate_deprecated_IR_V2, finegrain_fusing, disable_fusing,
        disable_resnet_optimization, reverse_input_channels, move_to_preprocess
        and data_type.
    :param input_model: model passed to load_symbol_def.
    :param output_model_name: graph/IR name.
    :param output_dir: directory passed to prepare_emit_ir.
    :return: 0 after the IR has been emitted.
    :raises FrameworkError: when load_symbol_def raises ValueError or MXNetError.
    """
    meta_info = get_meta_info(argv)

    try:
        model_nodes, model_params, model_name, iteration_number = load_symbol_def(input_model, argv.input_symbol,
                                                                                 argv.input,
                                                                                 argv.nd_prefix_name,
                                                                                 argv.pretrained_model_name,
                                                                                 argv.legacy_mxnet_model)
    except (ValueError, mxnet.base.MXNetError) as e:
        raise FrameworkError(
            'The following error happened while loading mxnet model {}: {}. ' +
            refer_to_faq_msg(53),
            input_model,
            str(e)
        ) from e

    # Only saves a params file when all three nd-related options are set.
    if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
        save_params_file(model_name, model_params._arg_params, model_params._aux_params, iteration_number)

    update_extractors_with_extensions(mxnet_op_extractors)
    graph = symbol2nx(model_nodes, model_params, argv.input)
    graph.check_empty_graph('symbol2nx. It may happen due to problems with loaded model')

    graph.__setattr__('name', output_model_name)
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'mxnet'
    graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
    extract_node_attrs(graph, mxnet_op_extractor)

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    convert_batch_norm(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up(graph)

        # Fusing the sequences of Mul/Add operations
        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        # Fusing linear operation to Convolution
        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    fuse_pad(graph)

    # Converting Mul->Add to ScaleShift node
    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up(graph)

    convert_mul_add_to_power(graph)
    graph_clean_up(graph)

    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up(graph)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    meta_info=meta_info)
    return 0