def _import_meta_graph_def(meta_graph_def):
    """Recreates a Graph saved in a `MetaGraphDef` proto.

    This function adds all the nodes from the meta graph def proto to the
    current graph, recreates all the collections, and returns a saver from
    saver_def.

    Args:
      meta_graph_def: `MetaGraphDef` protocol buffer.

    Returns:
      A saver constructed from `saver_def` in `meta_graph_def`, or a default
      `Saver()` when the proto has no `saver_def` field.
    """
    # Gathers the list of nodes we are interested in.
    importer.import_graph_def(meta_graph_def.graph_def, name="")

    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
        kind = col_def.WhichOneof("kind")
        if kind is None:
            logging.error(
                "Cannot identify data type for collection %s. Skipping.", key)
            continue

        from_proto = ops.get_from_proto_function(key)
        # Only take the proto path when the payload really is a bytes list.
        # This used to be an `assert`, which disappears under `python -O` and
        # would then silently drop the collection; any other kind now falls
        # through to the generic handling below.
        if from_proto and kind == "bytes_list":
            proto_type = ops.get_collection_proto_type(key)
            for value in col_def.bytes_list.value:
                proto = proto_type()
                proto.ParseFromString(value)
                ops.add_to_collection(key, from_proto(proto))
            continue

        field = getattr(col_def, kind)
        if kind == "node_list":
            # Values are element names; resolve them in the default graph.
            for value in field.value:
                col_op = ops.get_default_graph().as_graph_element(value)
                ops.add_to_collection(key, col_op)
        elif kind == "int64_list":
            # NOTE(opensource): This force conversion is to work around the
            # fact that Python2 distinguishes between int and long, while
            # Python3 has only int.
            for value in field.value:
                ops.add_to_collection(key, int(value))
        else:
            for value in field.value:
                ops.add_to_collection(key, value)

    if meta_graph_def.HasField("saver_def"):
        return Saver(saver_def=meta_graph_def.saver_def)
    return Saver()
def _handle_collection_def(multi_gpu_meta_graph_def, op_names_to_replicate,
                           num_replicas):
    """Rewrites the collections of `multi_gpu_meta_graph_def` in place.

    * `node_list` collections are replaced by the result of
      `get_new_col_def_of_node_list`.
    * `bytes_list` collections that have a registered from-proto function are
      removed, unless their key is one of the explicitly handled keys listed
      below; `bytes_list` collections without such a function are kept as-is.
    * Any other collection kind raises.

    Afterwards the queue-runner, local-variable and shard-info updaters are
    invoked on the meta graph.

    Args:
      multi_gpu_meta_graph_def: meta graph proto whose `collection_def` map
        is modified in place.
      op_names_to_replicate: forwarded unchanged to the helper functions.
      num_replicas: forwarded unchanged to the helper functions.

    Raises:
      RuntimeError: If a collection is neither a `node_list` nor a
        `bytes_list`.
    """
    graph_keys = tf.GraphKeys
    # bytes_list collections that are left in place here; per the original
    # notes they are handled explicitly afterwards
    # (e.g., QUEUE_RUNNERS, LOCAL_VARIABLES, ...).
    explicitly_handled = (
        graph_keys.QUEUE_RUNNERS,
        graph_keys.GLOBAL_VARIABLES,
        graph_keys.TRAINABLE_VARIABLES,
        graph_keys.MOVING_AVERAGE_VARIABLES,
        graph_keys.LOCAL_VARIABLES,
        graph_keys.MODEL_VARIABLES,
        graph_keys.GRADIENTS_INFO,
        graph_keys.GLOBAL_STEP,
    )

    collections = multi_gpu_meta_graph_def.collection_def
    doomed_keys = []
    for key, col_def in collections.items():
        kind = col_def.WhichOneof("kind")

        if kind == 'node_list':
            # Update node_list collections (e.g., GLOBAL_STEP, TRAIN_OP,
            # UPDATE_OP, LOSSES, ...)
            replacement = get_new_col_def_of_node_list(
                col_def, op_names_to_replicate, num_replicas)
            target = collections[key]
            target.Clear()
            target.CopyFrom(replacement)
            continue

        if kind != 'bytes_list':
            raise RuntimeError("Should not reach here")

        # Collections without a proto function (e.g., user defined strings)
        # are kept untouched, and so are the explicitly handled keys.
        # Everything else (e.g., COND_CONTEXT) is dropped.
        # TODO: Handle all protos in tf.GraphKeys
        has_proto_fn = ops.get_from_proto_function(key)
        if has_proto_fn and key not in explicitly_handled:
            doomed_keys.append(key)

    # Deletion is deferred so the map is not resized while being iterated.
    for key in doomed_keys:
        del collections[key]

    # Update QUEUE_RUNNERS and LOCAL_VARIABLES collection
    update_queue_runners(multi_gpu_meta_graph_def, op_names_to_replicate,
                         num_replicas)
    update_local_variables(multi_gpu_meta_graph_def, op_names_to_replicate,
                           num_replicas)
    update_shard_info_for_in_graph(multi_gpu_meta_graph_def, num_replicas)
def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
    """Restores collections that we need to keep.

    Restoration is best-effort: items that cannot be rebuilt from their proto
    or whose graph element cannot be found in `dest_graph` are skipped
    silently.

    Args:
      dest_graph: graph object that receives the restored items through
        `add_to_collection` and resolves names through `as_graph_element`.
      src_meta_graph_def: meta graph proto whose `collection_def` map is
        read.
      collection_keys: iterable of collection keys to restore.
    """
    scope = ""
    for key in collection_keys:
        collection_def = src_meta_graph_def.collection_def[key]
        kind = collection_def.WhichOneof("kind")
        if kind is None:
            tf_logging.error(
                "Cannot identify data type for collection %s. Skipping.", key)
            continue

        from_proto = ops.get_from_proto_function(key)
        if from_proto and kind == "bytes_list":
            proto_type = ops.get_collection_proto_type(key)
            # It is assumed that there are no Variables Keys in collections
            for value in collection_def.bytes_list.value:
                proto = proto_type()
                proto.ParseFromString(value)
                try:
                    new_value = from_proto(proto, import_scope=scope)
                except Exception:  # pylint: disable=broad-except
                    # Deliberately best-effort, but no longer a bare
                    # `except:` so KeyboardInterrupt / SystemExit propagate.
                    continue
                dest_graph.add_to_collection(key, new_value)
            continue

        field = getattr(collection_def, kind)
        if kind == "node_list":
            for value in field.value:
                name = ops.prepend_name_scope(value, scope)
                # Since the graph has been optimized, the node may no longer
                # exist.
                try:
                    col_op = dest_graph.as_graph_element(name)
                except (TypeError, ValueError, KeyError):
                    continue
                dest_graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
            # NOTE(opensource): This force conversion is to work around the
            # fact that Python2 distinguishes between int and long, while
            # Python3 has only int.
            for value in field.value:
                dest_graph.add_to_collection(key, int(value))
        else:
            for value in field.value:
                dest_graph.add_to_collection(
                    key, ops.prepend_name_scope(value, scope))
def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
    """Restores collections that we need to keep.

    Restoration is best-effort: items that cannot be rebuilt from their proto
    or whose graph element cannot be found in `dest_graph` are skipped
    silently.

    Args:
      dest_graph: graph object that receives the restored items through
        `add_to_collection` and resolves names through `as_graph_element`.
      src_meta_graph_def: meta graph proto whose `collection_def` map is
        read.
      collection_keys: iterable of collection keys to restore.
    """
    scope = ""
    for key in collection_keys:
        collection_def = src_meta_graph_def.collection_def[key]
        kind = collection_def.WhichOneof("kind")
        if kind is None:
            tf_logging.error(
                "Cannot identify data type for collection %s. Skipping.", key)
            continue

        from_proto = ops.get_from_proto_function(key)
        if from_proto and kind == "bytes_list":
            proto_type = ops.get_collection_proto_type(key)
            # It is assumed that there are no Variables Keys in collections
            for value in collection_def.bytes_list.value:
                proto = proto_type()
                proto.ParseFromString(value)
                try:
                    new_value = from_proto(proto, import_scope=scope)
                except Exception:  # pylint: disable=broad-except
                    # Deliberately best-effort, but no longer a bare
                    # `except:` so KeyboardInterrupt / SystemExit propagate.
                    continue
                dest_graph.add_to_collection(key, new_value)
            continue

        field = getattr(collection_def, kind)
        if kind == "node_list":
            for value in field.value:
                name = ops.prepend_name_scope(value, scope)
                # Since the graph has been optimized, the node may no longer
                # exist.
                try:
                    col_op = dest_graph.as_graph_element(name)
                except (TypeError, ValueError, KeyError):
                    continue
                dest_graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
            # NOTE(opensource): This force conversion is to work around the
            # fact that Python2 distinguishes between int and long, while
            # Python3 has only int.
            for value in field.value:
                dest_graph.add_to_collection(key, int(value))
        else:
            for value in field.value:
                dest_graph.add_to_collection(
                    key, ops.prepend_name_scope(value, scope))
def import_scoped_meta_graph_with_return_elements(
        meta_graph_or_file,
        clear_devices=False,
        graph=None,
        import_scope=None,
        input_map=None,
        unbound_inputs_col_name="unbound_inputs",
        restore_collections_predicate=(lambda key: True),
        return_elements=None):
    """Imports graph from `MetaGraphDef` and returns vars and return elements.

    This function takes a `MetaGraphDef` protocol buffer as input. If
    the argument is a file containing a `MetaGraphDef` protocol buffer,
    it constructs a protocol buffer from the file content. The function
    then adds all the nodes from the `graph_def` field to the
    current graph, recreates the desired collections, and returns a dictionary
    of all the Variables imported into the name scope.

    In combination with `export_scoped_meta_graph()`, this function can be
    used to

    * Serialize a graph along with other Python objects such as `QueueRunner`,
      `Variable` into a `MetaGraphDef`.

    * Restart training from a saved graph and checkpoints.

    * Run inference from a saved graph and checkpoints.

    Args:
      meta_graph_or_file: `MetaGraphDef` protocol buffer or filename
        (including the path) containing a `MetaGraphDef`.
      clear_devices: Boolean which controls whether to clear device
        information from graph_def. Default false.
      graph: The `Graph` to import into. If `None`, use the default graph.
      import_scope: Optional `string`. Name scope into which to import the
        subgraph. If `None`, the graph is imported to the root name scope.
      input_map: A dictionary mapping input names (as strings) in `graph_def`
        to `Tensor` objects. The values of the named input tensors in the
        imported graph will be re-mapped to the respective `Tensor` values.
      unbound_inputs_col_name: Collection name for looking up unbound inputs.
      restore_collections_predicate: a predicate on collection names. A
        collection named c (i.e. whose key is c) will be restored iff
        1) `restore_collections_predicate(c)` is True, and
        2) `c != unbound_inputs_col_name`.
      return_elements: A list of strings containing operation names in the
        `MetaGraphDef` that will be returned as `Operation` objects; and/or
        tensor names in `MetaGraphDef` that will be returned as `Tensor`
        objects.

    Returns:
      A tuple of (
        dictionary of all the `Variables` imported into the name scope,
        list of `Operation` or `Tensor` objects from the `return_elements`
        list).

    Raises:
      ValueError: If the graph_def contains unbound inputs, or if eager
        execution is enabled.
    """
    if context.executing_eagerly():
        raise ValueError(
            "Exporting/importing meta graphs is not supported when "
            "eager execution is enabled.")
    if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
        meta_graph_def = meta_graph_or_file
    else:
        meta_graph_def = read_meta_graph_file(meta_graph_or_file)

    if unbound_inputs_col_name:
        for key, col_def in meta_graph_def.collection_def.items():
            if key == unbound_inputs_col_name:
                kind = col_def.WhichOneof("kind")
                field = getattr(col_def, kind)
                # Decode once so that both the equality check and the
                # "missing" filter compare like with like; filtering on the
                # raw values never matched the string keys of `input_map`
                # when the collection is stored as bytes.
                unbound_names = [compat.as_str(v) for v in field.value]
                if unbound_names and (
                        not input_map or
                        sorted(unbound_names) != sorted(input_map)):
                    missing = [
                        name for name in unbound_names
                        if not input_map or name not in input_map
                    ]
                    raise ValueError(
                        "Graph contains unbound inputs: %s. Must "
                        "provide these inputs through input_map." %
                        ",".join(missing))
                break

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Gathers the list of nodes we are interested in.
    with graph.as_default():
        producer_op_list = None
        if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
            producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
        input_graph_def = meta_graph_def.graph_def
        # Remove all the explicit device specifications for this node. This
        # helps to make the graph more portable.
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        scope_to_prepend_to_names = graph.unique_name(
            import_scope or "", mark_as_used=False)

        imported_return_elements = importer.import_graph_def(
            input_graph_def,
            name=(import_scope or scope_to_prepend_to_names),
            input_map=input_map,
            producer_op_list=producer_op_list,
            return_elements=return_elements)

        # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
        # without a VariableDef.trainable field set.
        tf_version = meta_graph_def.meta_info_def.tensorflow_version
        if not tf_version:
            variables_have_trainable = True
        else:
            variables_have_trainable = (
                packaging_version.parse(tf_version) >=
                packaging_version.parse("1.9"))

        # Sort collections so we see TRAINABLE_VARIABLES first and can default
        # these variables to trainable if the value is not set in their
        # VariableDef.
        sorted_collections = []
        if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
            sorted_collections.append(
                (ops.GraphKeys.TRAINABLE_VARIABLES,
                 meta_graph_def.collection_def[
                     ops.GraphKeys.TRAINABLE_VARIABLES]))
        for key, value in sorted(meta_graph_def.collection_def.items()):
            if key != ops.GraphKeys.TRAINABLE_VARIABLES:
                sorted_collections.append((key, value))

        # Restores all the other collections.
        # Maps serialized variable proto -> created object, so a variable
        # listed in several collections is only built once.
        variable_objects = {}
        for key, col_def in sorted_collections:
            # Don't add unbound_inputs to the new graph.
            if key == unbound_inputs_col_name:
                continue
            if not restore_collections_predicate(key):
                continue

            kind = col_def.WhichOneof("kind")
            if kind is None:
                logging.error(
                    "Cannot identify data type for collection %s. Skipping.",
                    key)
                continue
            from_proto = ops.get_from_proto_function(key)

            # Temporary change to allow the TFMA evaluator to read metric
            # variables saved as a bytes list.
            # TODO(kathywu): Remove this hack once cl/248406059 has been
            # submitted.
            if key == ops.GraphKeys.METRIC_VARIABLES:
                # Metric variables will use the same proto functions as
                # GLOBAL_VARIABLES
                from_proto = ops.get_from_proto_function(
                    ops.GraphKeys.GLOBAL_VARIABLES)

            if from_proto and kind == "bytes_list":
                proto_type = ops.get_collection_proto_type(key)
                if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
                    for value in col_def.bytes_list.value:
                        variable = variable_objects.get(value, None)
                        if variable is None:
                            proto = proto_type()
                            proto.ParseFromString(value)
                            if not variables_have_trainable:
                                # If the VariableDef proto does not contain a
                                # "trainable" property because it was exported
                                # before that property was added, we default
                                # it to whether the variable is in the
                                # TRAINABLE_VARIABLES collection. We've sorted
                                # TRAINABLE_VARIABLES to be first, so
                                # trainable variables will be created from
                                # that collection.
                                proto.trainable = (
                                    key == ops.GraphKeys.TRAINABLE_VARIABLES)
                            variable = from_proto(
                                proto, import_scope=scope_to_prepend_to_names)
                            variable_objects[value] = variable
                        graph.add_to_collection(key, variable)
                else:
                    for value in col_def.bytes_list.value:
                        proto = proto_type()
                        proto.ParseFromString(value)
                        graph.add_to_collection(
                            key,
                            from_proto(
                                proto,
                                import_scope=scope_to_prepend_to_names))
            else:
                field = getattr(col_def, kind)
                if key in _COMPAT_COLLECTION_LIST:
                    logging.warning(
                        "The saved meta_graph is possibly from an older release:\n"
                        "'%s' collection should be of type 'byte_list', but instead "
                        "is of type '%s'.", key, kind)
                if kind == "node_list":
                    for value in field.value:
                        col_op = graph.as_graph_element(
                            ops.prepend_name_scope(
                                value, scope_to_prepend_to_names))
                        graph.add_to_collection(key, col_op)
                elif kind == "int64_list":
                    # NOTE(opensource): This force conversion is to work
                    # around the fact that Python2 distinguishes between int
                    # and long, while Python3 has only int.
                    for value in field.value:
                        graph.add_to_collection(key, int(value))
                else:
                    for value in field.value:
                        graph.add_to_collection(
                            key,
                            ops.prepend_name_scope(
                                value, scope_to_prepend_to_names))

        # Key the result by variable name with the import scope stripped.
        var_list = {}
        variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                         scope=scope_to_prepend_to_names)
        for v in variables:
            var_list[ops.strip_name_scope(
                v.name, scope_to_prepend_to_names)] = v

    return var_list, imported_return_elements
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
    # pylint: disable=line-too-long
    """Builds a `MetaGraphDef` protocol buffer from the given pieces.

    Args:
      meta_info_def: `MetaInfoDef` protocol buffer. Its TensorFlow version
        fields are overwritten with the current build's values.
      graph_def: `GraphDef` protocol buffer. When not given, the graph is
        serialized with shapes added.
      saver_def: `SaverDef` protocol buffer.
      collection_list: List of string keys to collect. When `None`, all
        collection keys of the graph are used.
      graph: The `Graph` to create `MetaGraphDef` out of. Defaults to the
        default graph.
      export_scope: Optional `string`. Name scope to remove.
      exclude_nodes: An iterable of nodes or `string` node names to omit from
        all collection, or None.
      clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection. Note this method does not alter the graph, so any
        extraneous Save/Restore ops should have been removed already, as
        needed.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

    Returns:
      MetaGraphDef protocol buffer.

    Raises:
      TypeError: If the arguments are not of the correct proto buffer type.
    """
    # pylint: enable=line-too-long
    # Validate argument types up front.
    if graph and not isinstance(graph, ops.Graph):
        raise TypeError(
            f"graph must be of type Graph. Received type: {type(graph)}.")
    if meta_info_def and not isinstance(
            meta_info_def, meta_graph_pb2.MetaGraphDef.MetaInfoDef):
        raise TypeError("meta_info_def must be of type MetaInfoDef. "
                        f"Received type: {type(meta_info_def)}.")
    if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be of type GraphDef. "
                        f"Received type: {type(graph_def)}.")
    if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
        raise TypeError(f"saver_def must be of type SaverDef. "
                        f"Received type: {type(saver_def)}.")

    # Fall back to the default graph when none was supplied.
    graph = graph or ops.get_default_graph()

    result = meta_graph_pb2.MetaGraphDef()

    # meta_info_def: stamp it with the running TF build before merging.
    if not meta_info_def:
        meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
    meta_info_def.tensorflow_version = versions.__version__
    meta_info_def.tensorflow_git_version = versions.__git_version__
    result.meta_info_def.MergeFrom(meta_info_def)

    # graph_def: the explicit one wins, otherwise serialize the graph.
    if graph_def:
        result.graph_def.MergeFrom(graph_def)
    else:
        result.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))

    # Populate stripped_op_list from the graph's ops unless already present.
    if not result.meta_info_def.stripped_op_list.op:
        result.meta_info_def.stripped_op_list.MergeFrom(
            stripped_op_list_for_graph(result.graph_def))

    # Optionally drop default-valued attributes from the graph_def.
    if strip_default_attrs:
        strip_graph_default_valued_attrs(result)

    if saver_def:
        result.saver_def.MergeFrom(saver_def)

    # Collections: explicit list if provided, else everything in the graph.
    keys_to_export = (collection_list if collection_list is not None
                      else graph.get_all_collection_keys())
    for ctype in keys_to_export:
        extra_kwargs = {}
        if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
            # Avoid importing Saver here
            from_proto = ops.get_from_proto_function(ctype)
            extra_kwargs["override_contents"] = [from_proto(saver_def)]
        add_collection_def(result,
                           ctype,
                           graph=graph,
                           export_scope=export_scope,
                           exclude_nodes=exclude_nodes,
                           **extra_kwargs)
    return result
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
    """Recreates a `Graph` saved in a `MetaGraphDef` proto.

    This function takes a `MetaGraphDef` protocol buffer as input. If
    the argument is a file containing a `MetaGraphDef` protocol buffer,
    it constructs a protocol buffer from the file content. The function
    then adds all the nodes from the `graph_def` field to the
    current graph, recreates the desired collections, and returns a dictionary
    of all the Variables imported into the name scope.

    In combination with `export_scoped_meta_graph()`, this function can be
    used to

    * Serialize a graph along with other Python objects such as `QueueRunner`,
      `Variable` into a `MetaGraphDef`.

    * Restart training from a saved graph and checkpoints.

    * Run inference from a saved graph and checkpoints.

    Args:
      meta_graph_or_file: `MetaGraphDef` protocol buffer or filename
        (including the path) containing a `MetaGraphDef`.
      clear_devices: Boolean which controls whether to clear device
        information from graph_def. Default false.
      graph: The `Graph` to import into. If `None`, use the default graph.
      import_scope: Optional `string`. Name scope into which to import the
        subgraph. If `None`, the graph is imported to the root name scope.
      input_map: A dictionary mapping input names (as strings) in `graph_def`
        to `Tensor` objects. The values of the named input tensors in the
        imported graph will be re-mapped to the respective `Tensor` values.
      unbound_inputs_col_name: Collection name for looking up unbound inputs.
      restore_collections_predicate: a predicate on collection names. A
        collection named c (i.e. whose key is c) will be restored iff
        1) `restore_collections_predicate(c)` is True, and
        2) `c != unbound_inputs_col_name`.

    Returns:
      A dictionary of all the `Variables` imported into the name scope.

    Raises:
      ValueError: If the graph_def contains unbound inputs, or if eager
        execution is enabled.
    """
    if context.executing_eagerly():
        raise ValueError(
            "Exporting/importing meta graphs is not supported when "
            "eager execution is enabled.")
    if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
        meta_graph_def = meta_graph_or_file
    else:
        meta_graph_def = read_meta_graph_file(meta_graph_or_file)

    if unbound_inputs_col_name:
        for key, col_def in meta_graph_def.collection_def.items():
            if key == unbound_inputs_col_name:
                kind = col_def.WhichOneof("kind")
                field = getattr(col_def, kind)
                # Decode once so that both the equality check and the
                # "missing" filter compare like with like; filtering on the
                # raw values never matched the string keys of `input_map`
                # when the collection is stored as bytes.
                unbound_names = [compat.as_str(v) for v in field.value]
                if unbound_names and (
                        not input_map or
                        sorted(unbound_names) != sorted(input_map)):
                    missing = [
                        name for name in unbound_names
                        if not input_map or name not in input_map
                    ]
                    raise ValueError(
                        "Graph contains unbound inputs: %s. Must "
                        "provide these inputs through input_map." %
                        ",".join(missing))
                break

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Gathers the list of nodes we are interested in.
    with graph.as_default():
        producer_op_list = None
        if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
            producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
        input_graph_def = meta_graph_def.graph_def
        # Remove all the explicit device specifications for this node. This
        # helps to make the graph more portable.
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        scope_to_prepend_to_names = graph.unique_name(
            import_scope or "", mark_as_used=False)

        importer.import_graph_def(
            input_graph_def,
            name=(import_scope or scope_to_prepend_to_names),
            input_map=input_map,
            producer_op_list=producer_op_list)

        # Restores all the other collections.
        # Maps serialized variable proto -> created object, so a variable
        # listed in several collections is only built once.
        variable_objects = {}
        for key, col_def in sorted(meta_graph_def.collection_def.items()):
            # Don't add unbound_inputs to the new graph.
            if key == unbound_inputs_col_name:
                continue
            if not restore_collections_predicate(key):
                continue

            kind = col_def.WhichOneof("kind")
            if kind is None:
                logging.error(
                    "Cannot identify data type for collection %s. Skipping.",
                    key)
                continue
            from_proto = ops.get_from_proto_function(key)
            if from_proto and kind == "bytes_list":
                proto_type = ops.get_collection_proto_type(key)
                if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
                    for value in col_def.bytes_list.value:
                        variable = variable_objects.get(value, None)
                        if variable is None:
                            proto = proto_type()
                            proto.ParseFromString(value)
                            variable = from_proto(
                                proto, import_scope=scope_to_prepend_to_names)
                            variable_objects[value] = variable
                        graph.add_to_collection(key, variable)
                else:
                    for value in col_def.bytes_list.value:
                        proto = proto_type()
                        proto.ParseFromString(value)
                        graph.add_to_collection(
                            key,
                            from_proto(
                                proto,
                                import_scope=scope_to_prepend_to_names))
            else:
                field = getattr(col_def, kind)
                if key in _COMPAT_COLLECTION_LIST:
                    logging.warning(
                        "The saved meta_graph is possibly from an older release:\n"
                        "'%s' collection should be of type 'byte_list', but instead "
                        "is of type '%s'.", key, kind)
                if kind == "node_list":
                    for value in field.value:
                        col_op = graph.as_graph_element(
                            ops.prepend_name_scope(
                                value, scope_to_prepend_to_names))
                        graph.add_to_collection(key, col_op)
                elif kind == "int64_list":
                    # NOTE(opensource): This force conversion is to work
                    # around the fact that Python2 distinguishes between int
                    # and long, while Python3 has only int.
                    for value in field.value:
                        graph.add_to_collection(key, int(value))
                else:
                    for value in field.value:
                        graph.add_to_collection(
                            key,
                            ops.prepend_name_scope(
                                value, scope_to_prepend_to_names))

        # Key the result by variable name with the import scope stripped.
        var_list = {}
        variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                         scope=scope_to_prepend_to_names)
        for v in variables:
            var_list[ops.strip_name_scope(
                v.name, scope_to_prepend_to_names)] = v

    return var_list
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
    # pylint: disable=line-too-long
    """Construct and returns a `MetaGraphDef` protocol buffer.

    Args:
      meta_info_def: `MetaInfoDef` protocol buffer. Its TensorFlow version
        fields are overwritten with the current build's values.
      graph_def: `GraphDef` protocol buffer.
      saver_def: `SaverDef` protocol buffer.
      collection_list: List of string keys to collect.
      graph: The `Graph` to create `MetaGraphDef` out of.
      export_scope: Optional `string`. Name scope to remove.
      exclude_nodes: An iterable of nodes or `string` node names to omit from
        all collection, or None.
      clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection. Note this method does not alter the graph, so any
        extraneous Save/Restore ops should have been removed already, as
        needed.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

    Returns:
      MetaGraphDef protocol buffer.

    Raises:
      TypeError: If the arguments are not of the correct proto buffer type.
    """
    # pylint: enable=line-too-long
    # Type check. The messages are %-formatted here; passing the type as a
    # second argument to TypeError left the "%s" placeholder uninterpolated.
    if graph and not isinstance(graph, ops.Graph):
        raise TypeError("graph must be of type Graph, not %s" % type(graph))
    if meta_info_def and not isinstance(
            meta_info_def, meta_graph_pb2.MetaGraphDef.MetaInfoDef):
        raise TypeError("meta_info_def must be of type MetaInfoDef, not %s" %
                        type(meta_info_def))
    if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be of type GraphDef, not %s" %
                        type(graph_def))
    if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
        raise TypeError("saver_def must be of type SaverDef, not %s" %
                        type(saver_def))

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Creates a MetaGraphDef proto.
    meta_graph_def = meta_graph_pb2.MetaGraphDef()

    # Adds meta_info_def.
    if not meta_info_def:
        meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()

    # Set the tf version strings to the current tf build.
    meta_info_def.tensorflow_version = versions.__version__
    meta_info_def.tensorflow_git_version = versions.__git_version__
    meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

    # Adds graph_def or the default.
    if not graph_def:
        meta_graph_def.graph_def.MergeFrom(
            graph.as_graph_def(add_shapes=True))
    else:
        meta_graph_def.graph_def.MergeFrom(graph_def)

    # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
    # pylint: disable=g-explicit-length-test
    if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
        meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
            stripped_op_list_for_graph(meta_graph_def.graph_def))
    # pylint: enable=g-explicit-length-test

    # Strip default valued attributes in graph_def.
    if strip_default_attrs:
        _strip_graph_default_valued_attrs(meta_graph_def)

    # Adds saver_def.
    if saver_def:
        meta_graph_def.saver_def.MergeFrom(saver_def)

    # Adds collection_list.
    if collection_list is not None:
        clist = collection_list
    else:
        clist = graph.get_all_collection_keys()

    for ctype in clist:
        if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
            # Avoid importing Saver here
            from_proto = ops.get_from_proto_function(ctype)
            add_collection_def(meta_graph_def, ctype,
                               graph=graph,
                               export_scope=export_scope,
                               exclude_nodes=exclude_nodes,
                               override_contents=[from_proto(saver_def)])
        else:
            add_collection_def(meta_graph_def, ctype,
                               graph=graph,
                               export_scope=export_scope,
                               exclude_nodes=exclude_nodes)
    return meta_graph_def
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
    """Recreates a `Graph` saved in a `MetaGraphDef` proto.

    This function takes a `MetaGraphDef` protocol buffer as input. If
    the argument is a file containing a `MetaGraphDef` protocol buffer,
    it constructs a protocol buffer from the file content. The function
    then adds all the nodes from the `graph_def` field to the
    current graph, recreates the desired collections, and returns a dictionary
    of all the Variables imported into the name scope.

    In combination with `export_scoped_meta_graph()`, this function can be
    used to

    * Serialize a graph along with other Python objects such as `QueueRunner`,
      `Variable` into a `MetaGraphDef`.

    * Restart training from a saved graph and checkpoints.

    * Run inference from a saved graph and checkpoints.

    Args:
      meta_graph_or_file: `MetaGraphDef` protocol buffer or filename
        (including the path) containing a `MetaGraphDef`.
      clear_devices: Boolean which controls whether to clear device
        information from graph_def. Default false.
      graph: The `Graph` to import into. If `None`, use the default graph.
      import_scope: Optional `string`. Name scope into which to import the
        subgraph. If `None`, the graph is imported to the root name scope.
      input_map: A dictionary mapping input names (as strings) in `graph_def`
        to `Tensor` objects. The values of the named input tensors in the
        imported graph will be re-mapped to the respective `Tensor` values.
      unbound_inputs_col_name: Collection name for looking up unbound inputs.
      restore_collections_predicate: a predicate on collection names. A
        collection named c (i.e. whose key is c) will be restored iff
        1) `restore_collections_predicate(c)` is True, and
        2) `c != unbound_inputs_col_name`.

    Returns:
      A dictionary of all the `Variables` imported into the name scope.

    Raises:
      ValueError: If the graph_def contains unbound inputs, or if eager
        execution is enabled.
    """
    if context.in_eager_mode():
        raise ValueError(
            "Exporting/importing meta graphs is not supported when "
            "eager execution is enabled.")
    if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
        meta_graph_def = meta_graph_or_file
    else:
        meta_graph_def = read_meta_graph_file(meta_graph_or_file)

    if unbound_inputs_col_name:
        for key, col_def in meta_graph_def.collection_def.items():
            if key == unbound_inputs_col_name:
                kind = col_def.WhichOneof("kind")
                field = getattr(col_def, kind)
                # Decode once so that both the equality check and the
                # "missing" filter compare like with like; filtering on the
                # raw values never matched the string keys of `input_map`
                # when the collection is stored as bytes.
                unbound_names = [compat.as_str(v) for v in field.value]
                if unbound_names and (
                        not input_map or
                        sorted(unbound_names) != sorted(input_map)):
                    missing = [
                        name for name in unbound_names
                        if not input_map or name not in input_map
                    ]
                    raise ValueError(
                        "Graph contains unbound inputs: %s. Must "
                        "provide these inputs through input_map." %
                        ",".join(missing))
                break

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Gathers the list of nodes we are interested in.
    with graph.as_default():
        producer_op_list = None
        if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
            producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
        input_graph_def = meta_graph_def.graph_def
        # Remove all the explicit device specifications for this node. This
        # helps to make the graph more portable.
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        # Computed before the import (without reserving the name) and used
        # below to prefix the names stored in the collections.
        scope_to_prepend_to_names = graph.unique_name(
            import_scope or "", mark_as_used=False)

        importer.import_graph_def(
            input_graph_def,
            name=(import_scope or ""),
            input_map=input_map,
            producer_op_list=producer_op_list)

        # Restores all the other collections.
        for key, col_def in sorted(meta_graph_def.collection_def.items()):
            # Don't add unbound_inputs to the new graph.
            if key == unbound_inputs_col_name:
                continue
            if not restore_collections_predicate(key):
                continue

            kind = col_def.WhichOneof("kind")
            if kind is None:
                logging.error(
                    "Cannot identify data type for collection %s. Skipping.",
                    key)
                continue
            from_proto = ops.get_from_proto_function(key)
            if from_proto and kind == "bytes_list":
                proto_type = ops.get_collection_proto_type(key)
                for value in col_def.bytes_list.value:
                    proto = proto_type()
                    proto.ParseFromString(value)
                    graph.add_to_collection(
                        key,
                        from_proto(
                            proto, import_scope=scope_to_prepend_to_names))
            else:
                field = getattr(col_def, kind)
                if key in _COMPAT_COLLECTION_LIST:
                    logging.warning(
                        "The saved meta_graph is possibly from an older release:\n"
                        "'%s' collection should be of type 'byte_list', but instead "
                        "is of type '%s'.", key, kind)
                if kind == "node_list":
                    for value in field.value:
                        col_op = graph.as_graph_element(
                            ops.prepend_name_scope(
                                value, scope_to_prepend_to_names))
                        graph.add_to_collection(key, col_op)
                elif kind == "int64_list":
                    # NOTE(opensource): This force conversion is to work
                    # around the fact that Python2 distinguishes between int
                    # and long, while Python3 has only int.
                    for value in field.value:
                        graph.add_to_collection(key, int(value))
                else:
                    for value in field.value:
                        graph.add_to_collection(
                            key,
                            ops.prepend_name_scope(
                                value, scope_to_prepend_to_names))

        # Key the result by variable name with the import scope stripped.
        var_list = {}
        variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                         scope=scope_to_prepend_to_names)
        for v in variables:
            var_list[ops.strip_name_scope(
                v.name, scope_to_prepend_to_names)] = v

    return var_list
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to `graph`,
  recreates all the collections (except the unbound-inputs collection),
  and returns the variables found in the imported scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Note: when `clear_devices` is true the `device` field of every node in
  `meta_graph_def.graph_def` is cleared in place, so a `MetaGraphDef` proto
  passed in by the caller is modified.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
      A false value disables the unbound-inputs check.

  Returns:
    A dictionary of all the `Variables` imported into the name scope, keyed
    by variable name with `import_scope` stripped.

  Raises:
    ValueError: If the graph_def contains unbound inputs and the keys of
      `input_map` do not match them exactly.
    AssertionError: If a collection that has a registered `from_proto`
      function is not stored as a `bytes_list`.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key != unbound_inputs_col_name:
        continue
      kind = col_def.WhichOneof("kind")
      # `WhichOneof` yields None when no member of the oneof is set. Such a
      # collection carries no values, so there is nothing to verify; without
      # this guard `getattr(col_def, None)` would raise a TypeError.
      if kind is not None:
        field = getattr(col_def, kind)
        unbound_names = [compat.as_str(v) for v in field.value]
        # The keys of `input_map` must match the unbound inputs exactly.
        if field.value and (
            not input_map or sorted(unbound_names) != sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(unbound_names))
      break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable. This edits the proto in place.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    # NOTE(review): collections below are re-scoped with `import_scope` as
    # given; presumably this assumes the importer did not rename the scope
    # (e.g. because it was already in use in `graph`) -- confirm.
    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue

      from_proto = ops.get_from_proto_function(key)
      if from_proto:
        # Collections with a registered proto converter are expected to be
        # serialized protos; each one is parsed and rebuilt in `import_scope`.
        assert kind == "bytes_list", (
            "Collection %r has a from_proto function but is stored as %r, "
            "not 'bytes_list'." % (key, kind))
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
        continue

      field = getattr(col_def, kind)
      if kind == "node_list":
        # Node names are resolved to graph elements under the import scope.
        for value in field.value:
          col_op = graph.as_graph_element(
              ops.prepend_name_scope(value, import_scope))
          graph.add_to_collection(key, col_op)
      elif kind == "int64_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python2 distinguishes between int and long, while Python3 has
        # only int.
        for value in field.value:
          graph.add_to_collection(key, int(value))
      else:
        for value in field.value:
          graph.add_to_collection(
              key, ops.prepend_name_scope(value, import_scope))

    # Map scope-stripped variable names to the variables imported in scope.
    variables = graph.get_collection(ops.GraphKeys.VARIABLES,
                                     scope=import_scope)
    var_list = {
        ops.strip_name_scope(v.name, import_scope): v for v in variables
    }

  return var_list