def load(self, graph: Graph):
    """Read an MXNet model and fill ``graph`` with its nodes and attributes.

    The command-line parameters are taken from ``graph.graph['cmd_params']``.
    After the symbol definition is converted with ``symbol2nx``, the graph
    metadata keys ``layout``, ``fw`` and ``feature_dim`` are set, node
    attributes are extracted and op-name / shape info is sent.

    :param graph: graph to populate; must carry ``cmd_params`` in ``graph.graph``.
    :raises FrameworkError: when ``load_symbol_def`` fails with ``ValueError``
        or ``mxnet.base.MXNetError``.
    """
    cmd_params = graph.graph['cmd_params']

    try:
        # Unpacking stays inside the ``try`` so a ValueError caused by an
        # unexpected result shape is reported the same way as a load failure.
        loaded = load_symbol_def(
            cmd_params.input_model,
            cmd_params.input_symbol,
            cmd_params.input,
            cmd_params.nd_prefix_name,
            cmd_params.pretrained_model_name,
            cmd_params.legacy_mxnet_model,
        )
        model_nodes, model_params, model_name, iteration_number = loaded
    except (ValueError, mxnet.base.MXNetError) as e:
        message = 'The following error happened while loading mxnet model {}: {}. ' + refer_to_faq_msg(53)
        raise FrameworkError(message, cmd_params.input_model, str(e)) from e

    # Parameters are saved only when all three options are given.
    dump_requested = (cmd_params.nd_prefix_name
                      and cmd_params.pretrained_model_name
                      and cmd_params.save_params_from_nd)
    if dump_requested:
        save_params_file(
            model_name,
            model_params._arg_params,
            model_params._aux_params,
            iteration_number,
        )

    update_extractors_with_extensions(mxnet_op_extractors)
    symbol2nx(graph, model_nodes, model_params, cmd_params.input)
    graph.check_empty_graph('symbol2nx. It may happen due to problems with loaded model')

    graph.graph['layout'] = 'NCHW'
    graph.graph['fw'] = 'mxnet'
    if graph.graph['layout'] == 'NCHW':
        graph.graph['feature_dim'] = 1
    else:
        graph.graph['feature_dim'] = 3

    extract_node_attrs(graph, mxnet_op_extractor)
    send_op_names_info('mxnet', graph)
    send_shapes_info('mxnet', graph)
def load(self, graph: Graph):
    """Read a Kaldi model into ``graph`` and extract node attributes.

    The model path is ``graph.graph['cmd_params'].input_model``. After a
    successful load the graph metadata keys ``layout`` and ``fw`` are set,
    the Kaldi extractors are registered and op-name / shape info is sent.

    :param graph: graph to populate; must carry ``cmd_params`` in ``graph.graph``.
    :raises Error: when ``load_kaldi_model`` raises any ``Exception``.
    """
    cmd_params = graph.graph['cmd_params']

    try:
        load_kaldi_model(graph, cmd_params.input_model)
    except Exception as e:
        message = 'Model Optimizer is not able to parse Kaldi model {}. '.format(cmd_params.input_model)
        message = message + refer_to_faq_msg(91)
        raise Error(message) from e

    graph.check_empty_graph('load_kaldi_nnet_model')

    graph.graph['layout'] = 'NCHW'
    graph.graph['fw'] = 'kaldi'

    update_extractors_with_extensions(kaldi_type_extractors)

    def _extract(node):
        # Thin wrapper so the extractor is looked up at call time.
        return kaldi_extractor(node)

    extract_node_attrs(graph, _extract)

    send_op_names_info('kaldi', graph)
    send_shapes_info('kaldi', graph)
def load(self, graph: Graph):
    """Read a Caffe prototxt / caffemodel pair into ``graph``.

    The command-line parameters are taken from ``graph.graph['cmd_params']``.
    The method imports the Caffe protobuf module, loads the proto and model,
    converts them with ``loader.caffe_pb_to_nx``, records paths, name and
    metadata on the graph, registers custom-layer extractors and finally
    extracts node attributes and sends op-name / shape info.

    :param graph: graph to populate; must carry ``cmd_params`` in ``graph.graph``.
    :raises Error: when ``loader.caffe_pb_to_nx`` raises ``ValueError``.
    """
    cmd_params = graph.graph['cmd_params']
    caffe_pb2 = loader.import_caffe_pb2(cmd_params.caffe_parser_path)
    proto, model = loader.load_caffe_proto_model(
        caffe_pb2, cmd_params.input_proto, cmd_params.input_model)

    # Optional flags default to False when absent from the parameters.
    update_extractors_with_extensions(
        caffe_type_extractors,
        getattr(cmd_params, 'disable_omitting_optional', False),
        getattr(cmd_params, 'disable_flattening_optional_params', False),
    )

    try:
        original_shapes = loader.caffe_pb_to_nx(graph, proto, model)
    except ValueError as e:
        message = 'Invalid prototxt file: value error {}. ' + refer_to_faq_msg(11)
        raise Error(message, str(e)) from e
    graph.check_empty_graph('load_caffe_proto_model')

    setattr(graph, 'proto_path', cmd_params.input_proto)
    setattr(graph, 'caffemodel_path', cmd_params.input_model)
    # Prefer the name stored in the proto; fall back to the user-supplied one.
    setattr(graph, 'name', getattr(proto, 'name', None) or cmd_params.model_name)

    graph.graph['layout'] = 'NCHW'
    graph.graph['fw'] = 'caffe'
    graph.graph['original_shapes'] = original_shapes
    graph.graph['caffe_pb2'] = caffe_pb2

    custom_layers_map = custom_layers_mapping.load_layers_xml(cmd_params.k)
    custom_layers_mapping.update_extractors(
        caffe_type_extractors,
        custom_layers_map,
        getattr(cmd_params, 'disable_omitting_optional', False),
        getattr(cmd_params, 'enable_flattening_nested_params', False),
    )

    def _extract(node):
        return caffe_extractor(node, check_for_duplicates(caffe_type_extractors))

    extract_node_attrs(graph, _extract)

    send_op_names_info('caffe', graph)
    send_shapes_info('caffe', graph)
def load(self, graph: Graph):
    """Read an ONNX model into ``graph`` and extract node attributes.

    The model path is ``graph.graph['cmd_params'].input_model``. After the
    protobuf is converted with ``protobuf2nx`` the graph name and the
    metadata keys ``layout``, ``fw``, ``feature_dim`` and
    ``fw_opset_version`` are set, node attributes are extracted and
    op-name / shape info is sent.

    ``fw_opset_version`` is the version of the first ``opset_import`` entry,
    or ``None`` when the model declares no opset imports.

    :param graph: graph to populate; must carry ``cmd_params`` in ``graph.graph``.
    :raises Error: when ``protobuf2nx`` raises any ``Exception``.
    """
    argv = graph.graph['cmd_params']
    model_proto = load_onnx_model(argv.input_model)
    model_graph = model_proto.graph  # pylint: disable=no-member

    log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
    log.debug("Number of all input ports (not true inputs) in graph_def: {}".format(len(model_graph.input)))
    log.debug("Number of initializers in graph_def: {}".format(len(model_graph.initializer)))
    log.debug(
        "Number of real inputs in graph_def: {}".format(len(model_graph.input) - len(model_graph.initializer)))

    update_extractors_with_extensions(onnx_op_extractors)

    try:
        protobuf2nx(graph, model_proto)
    except Exception as e:
        raise Error(
            'Cannot pre-process ONNX graph after reading from model file "{}". '
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            argv.input_model,
            str(e)
        ) from e
    log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))

    graph.__setattr__('name',
                      argv.model_name if argv.model_name else model_proto.graph.name)  # pylint: disable=no-member
    graph.graph['layout'] = 'NCHW'
    graph.graph['fw'] = 'onnx'
    graph.graph['feature_dim'] = 1

    # A protobuf message always exposes its repeated fields, so ``hasattr``
    # alone cannot tell whether any opset was declared. Check for emptiness
    # as well, otherwise indexing ``[0]`` raises IndexError on models that
    # carry no opset_import entries.
    opset_imports = getattr(model_proto, 'opset_import', None)
    if opset_imports is not None and len(opset_imports) > 0:
        graph.graph['fw_opset_version'] = opset_imports[0].version  # pylint: disable=no-member
    else:
        graph.graph['fw_opset_version'] = None

    graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
    send_op_names_info('onnx', graph)
    send_shapes_info('onnx', graph)
def load(self, graph: Graph):
    """Run ``graph.check_empty_graph`` for the framework-loading stage.

    :param graph: graph to check.
    """
    stage_description = 'loading from framework'
    graph.check_empty_graph(stage_description)
def apply_transform(graph: Graph, replacer_cls, **kwargs):
    """Instantiate ``replacer_cls``, run it on ``graph`` and validate the result.

    The transform is skipped (with an info log) when it has ``enabled`` set to
    a falsy value or when any callable in its ``graph_condition`` returns a
    falsy value for ``graph``. Otherwise ``find_and_replace_pattern`` is run
    either on ``graph`` only (``run_not_recursively``) or on the graph and all
    sub-graphs. Optional ``force_clean_up`` and ``force_shape_inference``
    flags trigger a clean-up / shape inference pass afterwards, then the graph
    is checked for emptiness and shape consistency.

    :param graph: graph to transform.
    :param replacer_cls: transform class; instantiated with no arguments.
    :param kwargs: accepted for interface compatibility; not used here.
    :raises Error: wraps an ``Error`` raised while running the transform.
    :raises FrameworkError: re-raised with the same message.
    :raises Exception: wraps any other exception raised while running.
    """
    replacer = replacer_cls()
    replacement_id = getattr(replacer, 'replacement_id', 'REPLACEMENT_ID')

    def _flag(name):
        # Optional boolean attribute of the transform; absent means False.
        return bool(getattr(replacer, name, False))

    if not getattr(replacer, 'enabled', True):
        log.info("Skip replacer {} (enabled = False)".format(replacer_cls))
        return

    if hasattr(replacer, 'graph_condition'):
        # Evaluate every condition (no short-circuit), as a list.
        verdicts = [condition(graph) for condition in replacer.graph_condition]
        if not all(verdicts):
            log.info("Skip replacer {} (graph_condition not satisfied)".format(replacer_cls))
            return

    log.debug("Run replacer {}".format(replacer_cls))

    def _clean_up(sub_graph):
        return sub_graph.clean_up()

    # NOTE(review): the two validators below ignore the sub-graph they are
    # handed and always check the top-level ``graph`` -- confirm whether
    # per-sub-graph validation was intended before changing this.
    def _check_not_empty(_):
        return graph.check_empty_graph(replacer_cls)

    def _check_shapes(_):
        return graph.check_shapes_consistency()

    try:
        if _flag('run_not_recursively'):
            replacer.find_and_replace_pattern(graph)
        else:
            for_graph_and_each_sub_graph_recursively(graph, replacer.find_and_replace_pattern)

        if _flag('force_clean_up'):
            for_graph_and_each_sub_graph_recursively(graph, _clean_up)

        if _flag('force_shape_inference'):
            shape_inference(graph)

        if _flag('run_not_recursively'):
            graph.check_empty_graph(replacer_cls)
            graph.check_shapes_consistency()
        else:
            for_graph_and_each_sub_graph_recursively(graph, _check_not_empty)
            for_graph_and_each_sub_graph_recursively(graph, _check_shapes)

    except Error as err:
        details = str(err).replace('[REPLACEMENT_ID]', replacement_id)
        raise Error(
            'Exception occurred during running replacer "{}" ({}): {}'.format(
                replacement_id, replacer_cls, details)
        ) from err
    except FrameworkError as err:
        raise FrameworkError('{}'.format(str(err))) from err
    except Exception as err:
        details = str(err).replace('[REPLACEMENT_ID]', replacement_id)
        raise Exception(
            'Exception occurred during running replacer "{} ({})": {}'.format(
                replacement_id, replacer_cls, details)
        ) from err
def load(self, graph: Graph):
    """Read a TensorFlow model into ``graph`` and extract node attributes.

    The command-line parameters are taken from ``graph.graph['cmd_params']``.
    Steps, in order: load custom op libraries (if any), load the graph
    definition with ``load_tf_graph_def``, record the input/output order when
    it is provided as a tuple, try the optional ``tf_v1.import_graph_def``
    step, optionally dump the graph for TensorBoard, convert the protobuf
    with ``protobuf2nx``, set graph name and metadata, restore edges, add
    fake Identity outputs, drop control-dependency inputs, extract node
    attributes, and auto-detect the layout.

    :param graph: graph to populate; must carry ``cmd_params`` in ``graph.graph``.
    :raises Error: when ``protobuf2nx`` raises any ``Exception``.
    """
    argv = graph.graph['cmd_params']
    if argv.tensorflow_custom_layer_libraries:
        for library in argv.tensorflow_custom_layer_libraries.split(','):
            log.info('Loading library "{}" with custom operations'.format(library))
            tf_v1.load_op_library(library)

    graph_def, variables_values, framework, inputs_outputs_order = load_tf_graph_def(
        graph_file_name=argv.input_model,
        is_binary=not argv.input_model_is_text,
        checkpoint=argv.input_checkpoint,
        user_output_node_names_list=argv.output,
        model_dir=argv.saved_model_dir,
        meta_graph_file=argv.input_meta_graph,
        saved_model_tags=argv.saved_model_tags)

    # ``isinstance`` is already False for None, so no separate None check.
    if isinstance(inputs_outputs_order, tuple):
        graph.inputs_order = inputs_outputs_order[0]
        graph.outputs_order = inputs_outputs_order[1]

    send_framework_info(framework)

    try:
        tf_v1.import_graph_def(graph_def, name='')
    except Exception:
        # Deliberately best-effort: conversion continues on failure. Only
        # ``Exception`` is caught so that KeyboardInterrupt / SystemExit are
        # not swallowed.
        log.warning(
            "TensorFlow post-processing of loaded model was unsuccessful. "
            "This is an optional step that Model Optimizer performs for any input model but it is not usually "
            "required for all models. "
            "It likely means that the original model is ill-formed. "
            "Model Optimizer will continue converting this model.")

    log.debug("Number of nodes in graph_def: {}".format(len(graph_def.node)))  # pylint: disable=no-member

    if argv.tensorboard_logdir:
        tensorboard_util.dump_for_tensorboard(graph_def, argv.tensorboard_logdir)

    update_extractors_with_extensions(tf_op_extractors)

    try:
        protobuf2nx(graph, graph_def)
    except Exception as e:
        raise Error(
            'Cannot pre-process TensorFlow graph after reading from model file "{}". '
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            argv.model_name,
            str(e)
        ) from e

    graph.__setattr__('name', argv.model_name)
    # 'layout' parameter change may cause an issue in EltwiseInputReshape replacer
    # and convert_nhwc_to_nchw(graph)
    graph.graph['layout'] = 'NCHW' if argv.disable_nhwc_to_nchw else 'NHWC'
    graph.graph['fw'] = 'tf'

    graph.graph['variables_values'] = variables_values
    del variables_values

    used_tensors = restore_edges(graph, get_tf_edges)

    # Tensor names information corresponding to a node is stored on outgoing edges.
    # As output nodes do not have outgoing edges, fake outputs are required. In the following code
    # for each output an Identity node is added, and the tensor name for the output is kept
    # on the (output, fake output) edge. After the Result-node-adding transformation the fake
    # outputs are deleted from the graph.
    add_outputs_identity(
        graph,
        graph.nodes - used_tensors,
        lambda g, output, fake_node_name: g.add_edges_from([create_tf_edge(output, fake_node_name, 0)]))

    remove_control_dependency_inputs(graph)

    graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))

    # Try to detect layout from the nodes of the graph. If there are no convolution nodes in N(D)HWC layout
    # then we consider that the graph is in NCHW layout and no layout conversion should be performed.
    if not argv.disable_nhwc_to_nchw and not graph_or_sub_graph_has_nhwc_ops(graph):
        if not argv.silent:
            log.debug('"disable_nhwc_to_nchw" was automatically enabled.')
        for_graph_and_each_sub_graph_recursively(graph, update_cmd_params_and_layout)

    send_op_names_info(framework, graph)
    send_shapes_info(framework, graph)