def strip_meta_graph(meta_graph_def, node_names, var_names):
  node_names = node_names[:]
  collections = meta_graph_def.collection_def

  # Look for matching variable names and initializers and keep them too.
  var_def = variable_pb2.VariableDef()
  for var_col_name in ["variables", "trainable_variables"]:
    var_def_bs = collections[var_col_name].bytes_list.value
    for var_def_b in var_def_bs:
      var_def.ParseFromString(var_def_b)
      if var_def.variable_name not in var_names:
        # TODO(adamb) Should remove variable from collection.
        continue
      node_names.append(var_def.initializer_name)

  wc_def = control_flow_pb2.WhileContextDef()
  wc_values = collections["while_context"].bytes_list.value
  # Iterate in reverse so deletions don't shift the remaining indices.
  for wc_ix in range(len(wc_values) - 1, -1, -1):
    wc_bytes = wc_values[wc_ix]
    wc_def.ParseFromString(wc_bytes)
    unused = True
    wc_pivot_name = wc_def.pivot_name
    for name in node_names:
      if name.startswith(wc_pivot_name):
        unused = False
        break
    if unused:
      del wc_values[wc_ix]

  graph_def = meta_graph_def.graph_def
  eprint("only keeping", node_names, "from", [n.name for n in graph_def.node])
  graph_def = graph_util.extract_sub_graph(graph_def, node_names)
  meta_graph_def.graph_def.CopyFrom(graph_def)
def strip_pruning_vars_fn(input_graph_def, output_node_names):
  """Removes mask variable from the graph.

  Replaces the masked_weight tensor with element-wise multiplication of mask
  and the corresponding weight variable.

  Args:
    input_graph_def: A GraphDef in which the variables have been converted to
      constants. This is typically the output of
      tf.graph_util.convert_variables_to_constants()
    output_node_names: List of name strings for the result nodes of the graph

  Returns:
    A GraphDef in which pruning-related variables have been removed
  """
  masked_weights_dict = _get_masked_weights(input_graph_def)
  pruned_graph_def = graph_pb2.GraphDef()

  # Replace masked_weight with a const op containing the
  # result of tf.multiply(mask, weight)
  for node in input_graph_def.node:
    output_node = node_def_pb2.NodeDef()
    if 'masked_weight' in node.name:
      output_node.op = 'Const'
      output_node.name = node.name
      dtype = node.attr['T']
      data = masked_weights_dict[node.name]
      output_node.attr['dtype'].CopyFrom(dtype)
      output_node.attr['value'].CopyFrom(
          attr_value_pb2.AttrValue(
              tensor=tensor_util.make_tensor_proto(data)))
    else:
      output_node.CopyFrom(node)
    pruned_graph_def.node.extend([output_node])

  # Remove stranded nodes: mask and weights
  return graph_util.extract_sub_graph(pruned_graph_def, output_node_names)
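A minimal usage sketch for the function above. It assumes a TF1 `Session` named `sess` whose graph was built with model-pruning masks, and an output node named "logits"; both names are hypothetical.

# Hedged sketch: freeze variables first, since strip_pruning_vars_fn expects a
# GraphDef whose variables are already constants.
frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
    sess, sess.graph_def, ["logits"])  # "logits" is a hypothetical output node
pruned_graph_def = strip_pruning_vars_fn(frozen_graph_def, ["logits"])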
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type.

  Returns:
    A GraphDef with all unnecessary ops removed.
  """
  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  inputs_replaced_graph_def = tf.GraphDef()
  for node in input_graph_def.node:
    if node.name in input_node_names:
      placeholder_node = tf.NodeDef()
      placeholder_node.op = "Placeholder"
      placeholder_node.name = node.name
      placeholder_node.attr["dtype"].CopyFrom(
          tf.AttrValue(type=placeholder_type_enum))
      if "_output_shapes" in node.attr:
        placeholder_node.attr["_output_shapes"].CopyFrom(
            node.attr["_output_shapes"])
      inputs_replaced_graph_def.node.extend([placeholder_node])
    else:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names)
  return output_graph_def
def __init__(self,
             session,
             graph_def,
             output_node_names,
             variable_names_allowlist=None,
             variable_names_denylist=None):
  graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
  super(_SessionConverterData, self).__init__(
      graph_def,
      variable_names_allowlist=variable_names_allowlist,
      variable_names_denylist=variable_names_denylist)

  nodes_to_convert = []
  tensor_names_to_convert = []
  for node in self.graph_def.node:
    if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
      tensor_name = node.name
      if not self._should_convert(tensor_name):
        continue
      if node.op == "VarHandleOp":
        tensor_name = tensor_name + "/Read/ReadVariableOp"
      nodes_to_convert.append(node)
      tensor_names_to_convert.append(tensor_name + ":0")

  if tensor_names_to_convert:
    converted_tensors = session.run(tensor_names_to_convert)
    for node, tensor_value in zip(nodes_to_convert, converted_tensors):
      self._tensor_data[node.name] = _TensorData(
          numpy=tensor_value, dtype=node.attr["dtype"].type, index=None)
def test_remove_unneeded_nodes(self):
  a_constant_name = "a_constant"
  b_constant_name = "b_constant"
  a_check_name = "a_check"
  b_check_name = "b_check"
  a_identity_name = "a_identity"
  b_identity_name = "b_identity"
  add_name = "add"

  graph_def = tf.GraphDef()
  a_constant = quantize_graph.create_constant_node(
      a_constant_name, value=1, dtype=tf.float32, shape=[])
  graph_def.node.extend([a_constant])
  a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
                                            [a_constant_name])
  graph_def.node.extend([a_check_node])
  a_identity_node = quantize_graph.create_node(
      "Identity", a_identity_name, [a_constant_name, "^" + a_check_name])
  graph_def.node.extend([a_identity_node])
  b_constant = quantize_graph.create_constant_node(
      b_constant_name, value=1, dtype=tf.float32, shape=[])
  graph_def.node.extend([b_constant])
  b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
                                            [b_constant_name])
  graph_def.node.extend([b_check_node])
  b_identity_node = quantize_graph.create_node(
      "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
  graph_def.node.extend([b_identity_node])
  add_node = quantize_graph.create_node("Add", add_name,
                                        [a_identity_name, b_identity_name])
  quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
  graph_def.node.extend([add_node])

  expected_output = tf.GraphDef()
  a_constant = quantize_graph.create_constant_node(
      a_constant_name, value=1, dtype=tf.float32, shape=[])
  expected_output.node.extend([a_constant])
  b_constant = quantize_graph.create_constant_node(
      b_constant_name, value=1, dtype=tf.float32, shape=[])
  expected_output.node.extend([b_constant])
  add_node = quantize_graph.create_node("Add", add_name,
                                        [a_constant_name, b_constant_name])
  quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
  expected_output.node.extend([add_node])

  rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
  output = rewriter.remove_unneeded_nodes(graph_def)
  stripped_output = graph_util.extract_sub_graph(output, [add_name])
  self.assertProtoEquals(expected_output, stripped_output)
def testExtractSubGraph(self):
  graph_def = graph_pb2.GraphDef()
  n1 = graph_def.node.add()
  n1.name = "n1"
  n1.input.extend(["n5"])
  n2 = graph_def.node.add()
  n2.name = "n2"
  # Take the first output of the n1 node as the input.
  n2.input.extend(["n1:0"])
  n3 = graph_def.node.add()
  n3.name = "n3"
  # Add a control input (which isn't really needed by the kernel, but
  # rather to enforce execution order between nodes).
  n3.input.extend(["^n2"])
  n4 = graph_def.node.add()
  n4.name = "n4"

  # It is fine to have loops in the graph as well.
  n5 = graph_def.node.add()
  n5.name = "n5"
  n5.input.extend(["n1"])

  sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
  self.assertEqual("n1", sub_graph.node[0].name)
  self.assertEqual("n2", sub_graph.node[1].name)
  self.assertEqual("n3", sub_graph.node[2].name)
  self.assertEqual("n5", sub_graph.node[3].name)
def save_graph_only(sess, output_file_path, output_node_names, as_text=False):
  """Save a small version of the graph based on a session and the output node names."""
  # Note: `sess.graph_def` serializes a fresh copy on every access, so grab it
  # once; otherwise the device-clearing below would be applied to a throwaway
  # copy and silently lost.
  graph_def = sess.graph_def
  for node in graph_def.node:
    node.device = ''
  graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
  output_dir, output_filename = os.path.split(output_file_path)
  graph_io.write_graph(graph_def, output_dir, output_filename, as_text=as_text)
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input node name.

  Returns:
    A `GraphDef` with all unnecessary ops removed.

  Raises:
    ValueError: If any element in `input_node_names` refers to a tensor instead
      of an operation.
    KeyError: If any element in `input_node_names` is not found in the graph.
  """
  for name in input_node_names:
    if ":" in name:
      raise ValueError(f"Name '{name}' appears to refer to a Tensor, not an "
                       "Operation.")

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  not_found = {name for name in input_node_names}
  inputs_replaced_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name in input_node_names:
      not_found.remove(node.name)
      placeholder_node = node_def_pb2.NodeDef()
      placeholder_node.op = "Placeholder"
      placeholder_node.name = node.name
      if isinstance(placeholder_type_enum, list):
        input_node_index = input_node_names.index(node.name)
        placeholder_node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(
                type=placeholder_type_enum[input_node_index]))
      else:
        placeholder_node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=placeholder_type_enum))
      if "_output_shapes" in node.attr:
        placeholder_node.attr["_output_shapes"].CopyFrom(
            node.attr["_output_shapes"])
      if "shape" in node.attr:
        placeholder_node.attr["shape"].CopyFrom(node.attr["shape"])
      inputs_replaced_graph_def.node.extend([placeholder_node])
    else:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

  if not_found:
    raise KeyError(f"The following input nodes were not found: {not_found}.")

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names)
  return output_graph_def
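A usage sketch for the strip_unused variant above, assuming a frozen GraphDef stored at "frozen.pb" with ops named "input" and "logits" (the path and both names are hypothetical, for illustration only):

import tensorflow as tf
from tensorflow.python.framework import dtypes

graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile("frozen.pb", "rb") as f:  # hypothetical path
    graph_def.ParseFromString(f.read())

stripped_def = strip_unused(
    graph_def,
    input_node_names=["input"],    # hypothetical input op name
    output_node_names=["logits"],  # hypothetical output op name
    placeholder_type_enum=dtypes.float32.as_datatype_enum)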
def parse_pb(file_or_path, output_nodes=None):
    """
    Arguments
    =========
    - file_or_path: a file object or a path string of the pb file
    - output_nodes: list of output node names

    Returns
    =======
    - nodes: mapping from node name to its NodeDef object
    """
    if sys.version_info.major < 3:
        file_type = (file, io.IOBase)  # pylint: disable=E0602
    else:
        file_type = io.IOBase
    if isinstance(file_or_path, str):
        fid = open(file_or_path, "rb")
    elif isinstance(file_or_path, file_type):
        fid = file_or_path
    else:
        raise ValueError(
            "`file_or_path` has to be either file object or path string")

    # load pb file
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(fid.read())
    fid.close()

    if output_nodes:
        sub_graph_def = graph_util.extract_sub_graph(graph_def, output_nodes)
    else:
        sub_graph_def = graph_def
    # Key by node name, as the docstring promises; keying by op type would
    # silently drop nodes that share an op.
    return dict((node.name, node) for node in sub_graph_def.node)  # pylint: disable=E1101
def test_keep_control_edges(self):
  no_op_name = "no_op"
  a_constant_name = "a_constant"
  b_constant_name = "b_constant"
  a_check_name = "a_check"
  b_check_name = "b_check"
  a_identity_name = "a_identity"
  b_identity_name = "b_identity"
  add_name = "add"

  graph_def = graph_pb2.GraphDef()
  no_op = quantize_graph.create_node("NoOp", no_op_name, [])
  graph_def.node.extend([no_op])
  a_constant = quantize_graph.create_constant_node(
      a_constant_name, value=1, dtype=dtypes.float32, shape=[])
  graph_def.node.extend([a_constant])
  a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
                                            [a_constant_name])
  graph_def.node.extend([a_check_node])
  a_identity_node = quantize_graph.create_node(
      "Identity", a_identity_name,
      [a_constant_name, "^" + a_check_name, "^" + no_op_name])
  graph_def.node.extend([a_identity_node])
  b_constant = quantize_graph.create_constant_node(
      b_constant_name, value=1, dtype=dtypes.float32, shape=[])
  graph_def.node.extend([b_constant])
  b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
                                            [b_constant_name])
  graph_def.node.extend([b_check_node])
  b_identity_node = quantize_graph.create_node(
      "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
  graph_def.node.extend([b_identity_node])
  add_node = quantize_graph.create_node("Add", add_name,
                                        [a_identity_name, b_identity_name])
  quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
  graph_def.node.extend([add_node])

  expected_output = graph_pb2.GraphDef()
  no_op = quantize_graph.create_node("NoOp", no_op_name, [])
  expected_output.node.extend([no_op])
  a_constant = quantize_graph.create_constant_node(
      a_constant_name, value=1, dtype=dtypes.float32, shape=[])
  expected_output.node.extend([a_constant])
  a_identity_node = quantize_graph.create_node(
      "Identity", a_identity_name, [a_constant_name, "^" + no_op_name])
  expected_output.node.extend([a_identity_node])
  b_constant = quantize_graph.create_constant_node(
      b_constant_name, value=1, dtype=dtypes.float32, shape=[])
  expected_output.node.extend([b_constant])
  add_node = quantize_graph.create_node("Add", add_name,
                                        [a_identity_name, b_constant_name])
  quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
  expected_output.node.extend([add_node])
  expected_output.versions.CopyFrom(graph_def.versions)
  expected_output.library.CopyFrom(graph_def.library)

  output = graph_util.remove_training_nodes(graph_def)
  stripped_output = graph_util.extract_sub_graph(output, [add_name])
  self.assertProtoEquals(expected_output, stripped_output)
def save_pb(self, pb_path):
    extracted_graph = graph_util.extract_sub_graph(
        self.sess.graph_def,
        [self.name + '/light_state', self.name + '/light_position'])
    constant_graph = graph_util.convert_variables_to_constants(
        self.sess, self.sess.graph_def,
        [n.name for n in extracted_graph.node])
    with tf.gfile.FastGFile(pb_path, mode='wb') as f:
        f.write(constant_graph.SerializeToString())
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input node name.

  Returns:
    A `GraphDef` with all unnecessary ops removed.

  Raises:
    ValueError: If any element in `input_node_names` refers to a tensor instead
      of an operation.
    KeyError: If any element in `input_node_names` is not found in the graph.
  """
  for name in input_node_names:
    if ":" in name:
      raise ValueError("Name '%s' appears to refer to a Tensor, "
                       "not an Operation." % name)

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  not_found = {name for name in input_node_names}
  inputs_replaced_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name in input_node_names:
      not_found.remove(node.name)
      placeholder_node = node_def_pb2.NodeDef()
      placeholder_node.op = "Placeholder"
      placeholder_node.name = node.name
      if isinstance(placeholder_type_enum, list):
        input_node_index = input_node_names.index(node.name)
        placeholder_node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(
                type=placeholder_type_enum[input_node_index]))
      else:
        placeholder_node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=placeholder_type_enum))
      if "_output_shapes" in node.attr:
        placeholder_node.attr["_output_shapes"].CopyFrom(
            node.attr["_output_shapes"])
      inputs_replaced_graph_def.node.extend([placeholder_node])
    else:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

  if not_found:
    raise KeyError("The following input nodes were not found: %s\n" %
                   not_found)

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names)
  return output_graph_def
def run(self):
    # Normalize feeds and fetch
    fetch = self.fetch.split(",") if isinstance(self.fetch, str) else self.fetch
    feeds = self.feeds.split(",") if isinstance(self.feeds, str) else self.feeds

    # Find latest SavedModel export in path_saved_model
    subdirs = [
        str(path) for path in Path(self.path_saved_model).iterdir()
        if path.is_dir() and "temp" not in str(path)
    ]
    latest = str(sorted(subdirs)[-1])
    LOGGER.info(f"Using SavedModel {latest}")

    # Reload SavedModel Graph, optimize and export
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        meta_graph_def = tf.compat.v1.saved_model.loader.load(sess, ["serve"], latest)
        graph_def = meta_graph_def.graph_def

        # Add table initializer if present, or create it
        if INIT_ALL_TABLES in {node.name for node in graph_def.node}:
            fetch.append(INIT_ALL_TABLES)
        else:
            table_initializers = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)
            if table_initializers:
                LOGGER.info(f"Adding {INIT_ALL_TABLES} Node to the graph")
                table_init_op = tf.group(*table_initializers, name=INIT_ALL_TABLES)
                node_def = table_init_op.node_def
                graph_def.node.append(node_def)
                fetch.append(INIT_ALL_TABLES)

        # Rename nodes
        graph_def = rename_nodes(graph_def, self.new_names)

        # Setup (create / remove) placeholders
        graph_def = make_placeholders(graph_def, feeds)

        # Keep only the part of the graph that produces the tensors in 'fetch'
        graph_def = extract_sub_graph(graph_def, fetch)

        # Replace variables by constants
        graph_def = freeze_graph_with_def_protos(
            input_graph_def=graph_def,
            input_saver_def=None,
            input_checkpoint=None,
            output_node_names=",".join(fetch),
            restore_op_name=None,
            filename_tensor_name=None,
            output_graph=None,
            clear_devices=True,
            initializer_nodes=None,
            variable_names_blacklist=",".join(self.blacklisted_variables),
            input_saved_model_dir=latest,
            saved_model_tags=["serve"],
        )
        tf.io.write_graph(
            graph_def,
            logdir=self.path_optimized_model,
            name=self.graph_name,
            as_text=False)
        LOGGER.info(
            f"Optimized Model successfully exported to "
            f"{self.path_optimized_model}/{self.graph_name}")
def tf_optimize(sess, input_names, output_names, graph_def):
    transforms = [
        "remove_nodes(op=Identity, op=CheckNumerics)",
        "fold_batch_norms",
        "fold_old_batch_norms",
        # fails: "fold_constants(ignore_errors=true)",
    ]
    needed_names = input_names + output_names
    graph_def = graph_util.extract_sub_graph(graph_def, needed_names)
    graph_def = TransformGraph(graph_def, input_names, output_names, transforms)
    return graph_def
def tf_optimize(sess, inputs, outputs, graph_def):
    """Optimize tensorflow graph for inference."""
    transforms = [
        "fold_constants(ignore_errors=true)",
        "fold_batch_norms",
        "fold_old_batch_norms",
    ]
    needed_names = ([utils.node_name(i) for i in inputs] +
                    [utils.node_name(i) for i in outputs])
    graph_def = graph_util.extract_sub_graph(graph_def, needed_names)
    graph_def = TransformGraph(graph_def, inputs, outputs, transforms)
    return graph_def
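A usage sketch for the tf_optimize above; the tensor names "input:0" and "output:0" and the pre-loaded `graph_def` are assumptions for illustration. Note that `sess` appears in the signature but is never used by this implementation:

# Hedged sketch: assumes graph_def is a frozen inference GraphDef already in
# scope, and that the graph actually contains these (hypothetical) tensors.
optimized_def = tf_optimize(
    None,  # sess is unused by this particular implementation
    ["input:0"],
    ["output:0"],
    graph_def)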
def __init__(self, meta_file, checkpoint_file, dest_nodes, inputShape=None,
             in_nodes=None):
    super(TensorflowParser, self).__init__()

    # load model files into TensorFlow graph
    if meta_file:
        model = TensorflowParser._load_meta(meta_file)

    if checkpoint_file:
        self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
        self.weight_loaded = True

    # extract subgraph using in_nodes and dest_nodes
    if in_nodes is not None and inputShape is not None:
        from tensorflow.python.tools import strip_unused_lib
        from tensorflow.python.framework import dtypes
        from tensorflow.python.platform import gfile
        input_node_names = in_nodes.split(',')
        output_node_names = dest_nodes.split(',')
        model = strip_unused_lib.strip_unused(
            input_graph_def=model,
            input_node_names=input_node_names,
            output_node_names=output_node_names,
            placeholder_type_enum=dtypes.float32.as_datatype_enum)

        input_list = [None]
        for i in range(len(inputShape)):
            input_list.append(tensorflow.Dimension(inputShape[i]))
        tensor_input = tensorflow.TensorShape(input_list)

        # Build network graph
        self.tf_graph = TensorflowGraph(model)
        for node in self.tf_graph.model.node:
            if node.name in input_node_names:
                node.attr['shape'].list.shape.extend([tensor_input.as_proto()])
                node.attr['_output_shapes'].list.shape.pop()  # unknown_rank pop
                node.attr['_output_shapes'].list.shape.extend(
                    [tensor_input.as_proto()])

    # extract subgraph using dest_nodes
    elif dest_nodes is not None:
        from tensorflow.python.framework.graph_util import extract_sub_graph
        model = extract_sub_graph(model, dest_nodes.split(','))
        self.tf_graph = TensorflowGraph(model)

    else:
        self.tf_graph = TensorflowGraph(model)

    self.tf_graph.build()
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.keras.backend.set_learning_phase(0)

    model = build_model()
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy())

    graph_def = K.get_session().graph.as_graph_def()
    graph_def = graph_util.extract_sub_graph(graph_def, [FLAGS.output_nodes])
    tf.train.write_graph(graph_def, FLAGS.save_dir, FLAGS.graph_filename,
                         as_text=True)
    print("Finished exporting inference graph: {}".format(FLAGS.save_dir))
def __init__(self, meta_file, checkpoint_file, frozen_file, dest_nodes=None):
    super(TensorflowParser, self).__init__()

    if meta_file:
        model = TensorflowParser._load_meta(meta_file)

    if checkpoint_file:
        self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
        self.weight_loaded = True

    if dest_nodes is not None:
        from tensorflow.python.framework.graph_util import extract_sub_graph
        model = extract_sub_graph(model, dest_nodes.split(','))

    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()
def load_meta(self):
    """Load a tensorflow meta file from disk

    Returns:
        graph: A tensorflow GraphDef protobuf
    """
    from tensorflow.core.protobuf import meta_graph_pb2
    meta_graph = meta_graph_pb2.MetaGraphDef()
    self.load_protobuf_from_file(meta_graph, self._tf_model_prefix + '.meta')
    graph = meta_graph.graph_def

    if self._dest_nodes is not None:
        from tensorflow.python.framework.graph_util import extract_sub_graph
        graph = extract_sub_graph(graph, self._dest_nodes.split(','))

    print("Tensorflow model file [%s] loaded successfully." %
          self._tf_model_prefix)
    return graph
def tf_optimize(inputs, outputs, graph_def, fold_constant=None):
    """Optimize tensorflow graph for inference."""
    transforms = []
    if fold_constant:
        transforms.extend([
            "fold_constants(ignore_errors=true)",
            # remove node colocation attributes
            "remove_attribute(attribute_name=_class)",
        ])
    transforms.extend([
        "fold_batch_norms",
        "fold_old_batch_norms",
    ])
    needed_names = ([utils.node_name(i) for i in inputs] +
                    [utils.node_name(i) for i in outputs])
    graph_def = graph_util.extract_sub_graph(graph_def, needed_names)
    graph_def = TransformGraph(graph_def, inputs, outputs, transforms)
    return graph_def
def extract_sub_graph(input_path, output_path=None, dest_nodes=None):
    if not output_path:
        output_path = append_file_name_suffix(input_path, "sub")

    logging.info("load from %s", input_path)
    graph_def = load_graph_def_from_pb(input_path)
    logging.info("\ttotal node = %s", len(graph_def.node))

    if dest_nodes:
        dest_nodes = dest_nodes.split(',')
    else:
        _, dest_nodes = get_graph_def_io_nodes(graph_def)

    graph_def = graph_util.extract_sub_graph(graph_def, dest_nodes)

    logging.info("save to %s", output_path)
    logging.info("\ttotal node = %s", len(graph_def.node))
    save_graph_def(graph_def, output_path)
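A usage sketch for the file-level helper above; the paths and node name are hypothetical, and the defaults rely on the append_file_name_suffix / get_graph_def_io_nodes helpers it calls:

# Explicit output path and destination nodes (comma-separated string):
extract_sub_graph("model.pb", output_path="model_sub.pb", dest_nodes="logits")

# Defaults: the output path is derived via append_file_name_suffix and the
# destination nodes are inferred via get_graph_def_io_nodes.
extract_sub_graph("model.pb")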
def run(self):
    # Find latest SavedModel export in path_saved_model
    subdirs = [
        str(path) for path in Path(self.path_saved_model).iterdir()
        if path.is_dir() and "temp" not in str(path)
    ]
    latest = str(sorted(subdirs)[-1])
    LOGGER.info(f"Using SavedModel {latest}")

    # Reload SavedModel Graph, optimize and export
    with tf.Session(graph=tf.Graph()) as sess:
        graph = tf.saved_model.loader.load(sess, ["serve"], latest)
        graph_def = graph.graph_def

        # Rename nodes
        graph_def = rename_nodes(graph_def, self.new_names)

        # Setup (create / remove) placeholders
        graph_def = make_placeholders(graph_def, self.feeds)

        # Keep only the part of the graph that produces tensor 'fetch'
        graph_def = extract_sub_graph(graph_def, [self.fetch])

        # Replace variables by constants
        graph_def = freeze_graph_with_def_protos(
            input_graph_def=graph_def,
            input_saver_def=None,
            input_checkpoint=None,
            output_node_names=self.fetch,
            restore_op_name=None,
            filename_tensor_name=None,
            output_graph=None,
            clear_devices=True,
            initializer_nodes=None,
            variable_names_blacklist=",".join(self.blacklisted_variables),
            input_saved_model_dir=latest,
            saved_model_tags=["serve"],
        )
        tf.io.write_graph(
            graph_def,
            logdir=self.path_optimized_model,
            name=self.graph_name,
            as_text=False)
        LOGGER.info(
            f"Online KNN successfully exported to "
            f"{self.path_optimized_model}/{self.graph_name}")
def __init__(self, meta_file, checkpoint_file, frozen_file, dest_nodes=None):
    super(TensorflowParser, self).__init__()

    # load model files into TensorFlow graph
    if meta_file:
        model = TensorflowParser._load_meta(meta_file)

    if checkpoint_file:
        self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
        self.weight_loaded = True

    if dest_nodes is not None:
        from tensorflow.python.framework.graph_util import extract_sub_graph
        model = extract_sub_graph(model, dest_nodes.split(','))

    # Build network graph
    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()
def __init__(self, input_args, dest_nodes=None):
    super(TensorflowParser, self).__init__()

    # load model files into TensorFlow graph
    from six import string_types as _string_types
    if isinstance(input_args, _string_types):
        model = TensorflowParser._load_meta(input_args)
    elif isinstance(input_args, tuple):
        model = TensorflowParser._load_meta(input_args[0])
        self.ckpt_data = TensorflowParser._load_weights(input_args[1])
        self.weight_loaded = True

    if dest_nodes is not None:
        from tensorflow.python.framework.graph_util import extract_sub_graph
        model = extract_sub_graph(model, dest_nodes.split(','))

    # Build network graph
    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
  """Removes unused nodes from a graph."""
  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  input_node_names_list = input_node_names.split(",")
  inputs_replaced_graph_def = tf.GraphDef()
  for node in input_graph_def.node:
    if node.name in input_node_names_list:
      placeholder_node = tf.NodeDef()
      placeholder_node.op = "Placeholder"
      placeholder_node.name = node.name
      placeholder_node.attr["dtype"].CopyFrom(
          tf.AttrValue(type=placeholder_type_enum))
      inputs_replaced_graph_def.node.extend([placeholder_node])
    else:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

  output_graph_def = graph_util.extract_sub_graph(
      inputs_replaced_graph_def, output_node_names.split(","))

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
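A usage sketch for the command-line oriented variant above; the file paths and node names are hypothetical. Note that this version takes comma-separated strings rather than lists:

from tensorflow.python.framework import dtypes

strip_unused(
    input_graph="graph.pb",        # hypothetical input path
    input_binary=True,
    output_graph="stripped.pb",    # hypothetical output path
    input_node_names="input",      # comma-separated string, not a list
    output_node_names="logits",
    placeholder_type_enum=dtypes.float32.as_datatype_enum)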
def tf_optimize(sess, inputs, outputs, graph_def):
    """Optimize tensorflow graph for inference."""
    transforms = [
        "fold_constants(ignore_errors=true)",
        "fold_batch_norms",
        "fold_old_batch_norms",
    ]
    # TODO: look into these two batch-norm transforms further.
    needed_names = ([utils.node_name(i) for i in inputs] +
                    [utils.node_name(i) for i in outputs])
    print("needed_names:", needed_names)
    graph_def = graph_util.extract_sub_graph(graph_def, needed_names)
    print("extract_sub_graph done")
    graph_def = TransformGraph(graph_def, inputs, outputs, transforms)
    print("TransformGraph done")
    return graph_def
def parse_pb(file_or_path, output_nodes=None):
    """
    Arguments
    =========
    - file_or_path: a file object or a path string of the pb file
    - output_nodes: list of output node names

    Returns
    =======
    - ops_info <dict>: a dict with information necessary for building
      context in uTensor
    - ops_topo <list>: list of op node names in topologically sorted order
    - output_nodes <list>: list of output node names
    """
    if sys.version_info.major < 3:
        file_type = (file, io.IOBase)  # pylint: disable=E0602
    else:
        file_type = io.IOBase
    if isinstance(file_or_path, str):
        fid = open(file_or_path, "rb")
    elif isinstance(file_or_path, file_type):
        fid = file_or_path
    else:
        raise ValueError(
            "`file_or_path` has to be either file object or path string")

    # load pb file
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(fid.read())
    fid.close()

    if output_nodes is not None:
        graph_def = graph_util.extract_sub_graph(graph_def, output_nodes)

    ops_info, ops_topo, output_nodes = _parse_graph_def(graph_def,
                                                        output_nodes)
    return ops_info, ops_topo, output_nodes
def testExtractSubGraphWithInvalidDestNodes(self):
  graph_def = graph_pb2.GraphDef()
  n1 = graph_def.node.add()
  n1.name = "n1"
  with self.assertRaisesRegexp(TypeError, "must be a list"):
    graph_util.extract_sub_graph(graph_def, "n1")
def remove_dead_nodes(self, output_names):
  """Removes nodes that are no longer needed for inference from the graph."""
  old_output_graph = self.output_graph
  self.output_graph = graph_util.extract_sub_graph(old_output_graph,
                                                   output_names)
def __init__(self, meta_file, checkpoint_file, dest_nodes, inputShape=None,
             in_nodes=None):
    super(TensorflowParser, self).__init__()

    # load model files into TensorFlow graph
    if meta_file:
        model = TensorflowParser._load_meta(meta_file)

    if checkpoint_file:
        self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
        self.weight_loaded = True

    # extract subgraph using in_nodes and dest_nodes
    if in_nodes is not None and inputShape is not None:
        from tensorflow.python.tools import strip_unused_lib
        from tensorflow.python.framework import dtypes
        from tensorflow.python.platform import gfile
        model = strip_unused_lib.strip_unused(
            input_graph_def=model,
            input_node_names=in_nodes,
            output_node_names=dest_nodes,
            placeholder_type_enum=dtypes.float32.as_datatype_enum)

        input_list = [None]
        for i in range(len(inputShape)):
            input_list.append(tensorflow.Dimension(inputShape[i]))
        tensor_input = tensorflow.TensorShape(input_list)

        # Build network graph
        self.tf_graph = TensorflowGraph(model)
        for node in self.tf_graph.model.node:
            if node.name in in_nodes:
                node.attr['shape'].shape.CopyFrom(tensor_input.as_proto())
                node.attr['_output_shapes'].list.shape.pop()  # unknown_rank pop
                node.attr['_output_shapes'].list.shape.extend(
                    [tensor_input.as_proto()])

    # extract subgraph using dest_nodes
    elif dest_nodes is not None:
        from tensorflow.python.framework.graph_util import extract_sub_graph
        model = extract_sub_graph(model, dest_nodes)
        self.tf_graph = TensorflowGraph(model)

    else:
        self.tf_graph = TensorflowGraph(model)

    # Graph Transform
    transforms = ["fold_constants(ignore_errors=true)"]

    # Get input node names if not given
    if not in_nodes:
        in_nodes = []
        for node in model.node:
            if node.op == 'Placeholder':
                in_nodes.append(node.name)

    transformed_graph_def = TransformGraph(model, in_nodes, dest_nodes,
                                           transforms)
    in_type_list = {}
    in_shape_list = {}
    for n in transformed_graph_def.node:
        if n.name in in_nodes:
            in_type_list[n.name] = n.attr['dtype'].type
            in_node_shape = n.attr['shape'].shape
            in_node_shape_str = self._shapeToStr(in_node_shape)
            in_shape_list[n.name] = in_node_shape_str

    dtype = tensorflow.float32
    with tensorflow.Graph().as_default() as g:
        input_map = {}
        for in_node in in_nodes:
            if in_type_list[in_node] == 1 or in_type_list[in_node] == 0:
                dtype = tensorflow.float32
            elif in_type_list[in_node] == 3:
                dtype = tensorflow.int32
            elif in_type_list[in_node] == 10:
                dtype = tensorflow.bool
            x = tensorflow.placeholder(dtype, shape=in_shape_list[in_node])
            input_map[in_node] = x

        tensorflow.import_graph_def(transformed_graph_def, name='',
                                    input_map=input_map)

        with tensorflow.Session(graph=g) as sess:
            tempdir = tempfile.mkdtemp()
            meta_graph_def = tensorflow.train.export_meta_graph(
                filename=os.path.join(tempdir, 'my-model.meta'))
            model = meta_graph_def.graph_def
            shutil.rmtree(tempdir)

    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()
    process_graph(self.tf_graph, self.ckpt_data)
def strip_unused(input_graph_def, input_tensor_names, output_tensor_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_tensor_names: A list of the tensors we use as inputs.
    output_tensor_names: A list of the output tensors.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input node name.

  Returns:
    A `GraphDef` with all unnecessary ops removed, and a map from the old
    input names to the new input names.

  Raises:
    ValueError: If any element in `input_tensor_names` refers to an operation
      instead of a tensor.
    KeyError: If any element in `input_tensor_names` is not found in the graph.
  """
  for name in input_tensor_names:
    if ":" not in name:
      raise ValueError("Input '%s' appears to refer to an Operation, "
                       "not a Tensor." % name)

  old2new = {}

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  not_found = {name for name in input_tensor_names}
  input_node_names = {name.split(":")[0] for name in input_tensor_names}
  output_node_names = list(
      {name.split(":")[0] for name in output_tensor_names})
  inputs_replaced_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name not in input_node_names:
      for i in range(len(node.input)):
        if _append_port(node.input[i]) in input_tensor_names:
          old_name = _append_port(node.input[i])
          not_found.remove(old_name)
          new_input_name = node.input[i].replace(":", "_")
          placeholder_node = node_def_pb2.NodeDef()
          placeholder_node.op = "Placeholder"
          placeholder_node.name = new_input_name
          if isinstance(placeholder_type_enum, list):
            input_node_index = input_tensor_names.index(old_name)
            placeholder_node.attr["dtype"].CopyFrom(
                attr_value_pb2.AttrValue(
                    type=placeholder_type_enum[input_node_index]))
          else:
            placeholder_node.attr["dtype"].CopyFrom(
                attr_value_pb2.AttrValue(type=placeholder_type_enum))
          if "_output_shapes" in node.attr:
            placeholder_node.attr["_output_shapes"].CopyFrom(
                node.attr["_output_shapes"])
          node.input[i] = new_input_name
          old2new[old_name] = new_input_name + ":0"
          inputs_replaced_graph_def.node.extend([placeholder_node])
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

  if not_found:
    raise KeyError("The following input nodes were not found: %s\n" %
                   not_found)

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names)
  return output_graph_def, old2new
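A usage sketch for the tensor-name variant above. Unlike the node-name versions, inputs must carry an explicit output port, and the second return value maps each original tensor name to its replacement placeholder tensor (the names below are hypothetical):

stripped_def, old2new = strip_unused(
    graph_def,
    input_tensor_names=["embedding/lookup:1"],  # hypothetical tensor name
    output_tensor_names=["logits:0"],
    placeholder_type_enum=dtypes.float32.as_datatype_enum)
# Per the code above, a non-zero port such as "embedding/lookup:1" yields a
# placeholder named "embedding/lookup_1", so
# old2new == {"embedding/lookup:1": "embedding/lookup_1:0"}.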
def parse_pb(file_or_path, output_nodes=None) -> (dict, list):
    """
    Arguments
    =========
    - file_or_path: a file object or a path string of the pb file
    - output_nodes: list of output node names

    Returns
    =======
    - graph_info <defaultdict>: a dict with information necessary for
      building context in uTensor
    - layers <list>: list of layers, where each layer is a list of
      operation names in the graph

    Note
    ====
    graph_info example:
      {
        'my_const': {
          "input_tensor": set([]),
          "output_tensor": set(["my_const:0"]),
          "output_content": {"my_const:0": 3.14},
          "op_type": "Const"
        },
        'myop': {
          "input_tensor": set(["input1:0", "input2:0"]),
          "output_tensor": set(["my_op:0", "my_op:1"]),
          "output_content": {},
          "op_type": "MyOp"
        },
        ...
      }

    layers example:

      `bottom` <--------> `top`
        foo -\\
              tar - - var
        bar -/

    the returned list, layers, will be [['foo', 'bar'], ['tar'], ['var']].
    That is, layers[0] is the bottom layer of the graph, layers[1] is the
    second-bottom layer, and so on.
    """
    if sys.version_info.major < 3:
        file_type = (file, io.IOBase)  # pylint: disable=E0602
    else:
        file_type = io.IOBase
    if isinstance(file_or_path, str):
        fid = open(file_or_path, "rb")
    elif isinstance(file_or_path, file_type):
        fid = file_or_path
    else:
        raise ValueError(
            "`file_or_path` has to be either file object or path string")

    # load pb file
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(fid.read())
    fid.close()

    if output_nodes is not None:
        graph_def = graph_util.extract_sub_graph(graph_def, output_nodes)

    graph_info, layers = _parse_graph_def(graph_def)
    return graph_info, layers
def test_remove_redundant_quantization(self):
  a_constant_name = "a_constant"
  a_constant_min_name = "a_constant_min"
  a_constant_max_name = "a_constant_max"
  a_dequantize_name = "a_dequantize"
  a_quantize_name = "a_quantize"
  b_constant_name = "b_constant"
  b_constant_min_name = "b_constant_min"
  b_constant_max_name = "b_constant_max"
  b_dequantize_name = "b_dequantize"
  b_quantize_name = "b_quantize"
  mat_mul_name = "mat_mul"

  graph_def = graph_pb2.GraphDef()
  a_constant = quantize_graph.create_constant_node(
      a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
  graph_def.node.extend([a_constant])
  a_constant_min = quantize_graph.create_constant_node(
      a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
  graph_def.node.extend([a_constant_min])
  a_constant_max = quantize_graph.create_constant_node(
      a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
  graph_def.node.extend([a_constant_max])
  a_dequantize_node = quantize_graph.create_node(
      "Dequantize", a_dequantize_name,
      [a_constant_name, a_constant_min_name, a_constant_max_name])
  quantize_graph.set_attr_dtype(a_dequantize_node, "T", dtypes.uint8)
  graph_def.node.extend([a_dequantize_node])
  a_quantize_node = quantize_graph.create_node(
      "QuantizeV2", a_quantize_name,
      [a_dequantize_name, a_dequantize_name + ":1", a_dequantize_name + ":2"])
  quantize_graph.set_attr_dtype(a_quantize_node, "T", dtypes.uint8)
  graph_def.node.extend([a_quantize_node])
  b_constant = quantize_graph.create_constant_node(
      b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
  graph_def.node.extend([b_constant])
  b_constant_min = quantize_graph.create_constant_node(
      b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
  graph_def.node.extend([b_constant_min])
  b_constant_max = quantize_graph.create_constant_node(
      b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
  graph_def.node.extend([b_constant_max])
  b_dequantize_node = quantize_graph.create_node(
      "Dequantize", b_dequantize_name,
      [b_constant_name, b_constant_min_name, b_constant_max_name])
  quantize_graph.set_attr_dtype(b_dequantize_node, "T", dtypes.uint8)
  graph_def.node.extend([b_dequantize_node])
  b_quantize_node = quantize_graph.create_node(
      "QuantizeV2", b_quantize_name,
      [b_dequantize_name, b_dequantize_name + ":1", b_dequantize_name + ":2"])
  quantize_graph.set_attr_dtype(b_quantize_node, "T", dtypes.uint8)
  graph_def.node.extend([b_quantize_node])
  mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
      a_quantize_name, b_quantize_name, a_quantize_name + ":1",
      a_quantize_name + ":2", b_quantize_name + ":1", b_quantize_name + ":2"
  ])
  quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
  quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
  graph_def.node.extend([mat_mul_node])

  expected_output = graph_pb2.GraphDef()
  a_constant = quantize_graph.create_constant_node(
      a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
  expected_output.node.extend([a_constant])
  a_constant_min = quantize_graph.create_constant_node(
      a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
  expected_output.node.extend([a_constant_min])
  a_constant_max = quantize_graph.create_constant_node(
      a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
  expected_output.node.extend([a_constant_max])
  b_constant = quantize_graph.create_constant_node(
      b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
  expected_output.node.extend([b_constant])
  b_constant_min = quantize_graph.create_constant_node(
      b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
  expected_output.node.extend([b_constant_min])
  b_constant_max = quantize_graph.create_constant_node(
      b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
  expected_output.node.extend([b_constant_max])
  mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
      a_constant_name, b_constant_name, a_constant_min_name,
      a_constant_max_name, b_constant_min_name, b_constant_max_name
  ])
  quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
  quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
  expected_output.node.extend([mat_mul_node])
  expected_output.versions.CopyFrom(graph_def.versions)
  expected_output.library.CopyFrom(graph_def.library)

  rewriter = quantize_graph.GraphRewriter(
      graph_def, [mat_mul_name], quantized_input_range=None)
  output = rewriter.remove_redundant_quantization(graph_def)
  stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
  self.assertProtoEquals(expected_output, stripped_output)