def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  Collections for which `ops.get_collection_proto_type(key)` returns a type
  are compared entry by entry after parsing their serialized `bytes_list`
  values; every other collection is compared as a raw `CollectionDef`. The
  remaining fields of the two protos are then compared with
  `tester.assertProtoEquals`.

  Note: this clears the `collection_def` field of both `a` and `b` in place.

  Args:
    tester: The test case instance providing the assertion methods.
    a: The first `MetaGraphDef`.
    b: The second `MetaGraphDef`.
  """
  # Both graphs must define exactly the same set of collection keys.
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  for k in a.collection_def.keys():
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same.
      tester.assertEqual(len(a_value.bytes_list.value),
                         len(b_value.bytes_list.value))
      for a_value_item, b_value_item in zip(a_value.bytes_list.value,
                                            b_value.bytes_list.value):
        # Compare parsed proto contents rather than the serialized bytes.
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # `assertEquals` is a deprecated alias removed from `unittest.TestCase`
      # in Python 3.12; `assertEqual` is the supported name.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")
  tester.assertProtoEquals(a, b)
def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  Collections for which `ops.get_collection_proto_type(key)` returns a type
  are compared entry by entry after parsing their serialized `bytes_list`
  values; every other collection is compared as a raw `CollectionDef`. The
  remaining fields of the two protos are then compared with
  `tester.assertProtoEquals`.

  Note: this clears the `collection_def` field of both `a` and `b` in place.

  Args:
    tester: The test case instance providing the assertion methods.
    a: The first `MetaGraphDef`.
    b: The second `MetaGraphDef`.
  """
  # Both graphs must define exactly the same set of collection keys.
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  for k in a.collection_def.keys():
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same.
      tester.assertEqual(len(a_value.bytes_list.value),
                         len(b_value.bytes_list.value))
      for a_value_item, b_value_item in zip(a_value.bytes_list.value,
                                            b_value.bytes_list.value):
        # Compare parsed proto contents rather than the serialized bytes.
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # `assertEquals` is a deprecated alias removed from `unittest.TestCase`
      # in Python 3.12; `assertEqual` is the supported name.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")
  tester.assertProtoEquals(a, b)
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Errors raised while filling the `CollectionDef` are logged as warnings and
  the partially written entry is removed from `meta_graph_def`.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections. Defaults to
      `ops.get_default_graph()` when not given.
    export_scope: Optional `string`. Name scope to remove.

  Raises:
    TypeError: If `graph` is given and is not an `ops.Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError("graph must be of type Graph, not %s" % type(graph))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  collection_list = graph.get_collection(key)
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect. Falsy results from `to_proto` are skipped.
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          # Explicit check instead of `assert`, which is stripped under `-O`.
          # The TypeError is handled by the `except` below like any other
          # serialization failure.
          if not isinstance(proto, proto_type):
            raise TypeError("proto %s is not type %s" % (proto, proto_type))
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend([x for x in collection_list])
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Error encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s", key, str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
def add_collection_def(meta_graph_def, key):
  """Serializes the graph collection `key` into `meta_graph_def`.

  Items are read with `ops.get_collection(key)`. If a to-proto function is
  registered for `key`, each item is converted, checked against the
  registered proto type and stored serialized in `bytes_list`. Otherwise the
  target field is chosen by `_get_kind_name` on the first item. Non-string
  keys are skipped with a warning, and empty collections are skipped
  silently. Any error while filling the entry is logged, and the partially
  written entry is removed.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
  """
  key_is_string = isinstance(key, six.string_types) or isinstance(key, bytes)
  if not key_is_string:
    logging.warning(
        "Only collections with string type keys will be "
        "serialized. This key has %s", type(key))
    return

  items = ops.get_collection(key)
  if not items:
    return

  try:
    collection_def = meta_graph_def.collection_def[key]
    serialize_fn = ops.get_to_proto_function(key)
    expected_type = ops.get_collection_proto_type(key)
    if serialize_fn:
      target = collection_def.bytes_list
      for item in items:
        proto = serialize_fn(item)
        # The registered serializer must produce the registered proto type.
        if not isinstance(proto, expected_type):
          raise TypeError("proto %s is not type %s" % (proto, expected_type))
        target.value.append(proto.SerializeToString())
    else:
      kind_name = _get_kind_name(items[0])
      target = getattr(collection_def, kind_name)
      if kind_name == "node_list":
        values = [item.name for item in items]
      elif kind_name == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        values = [compat.as_bytes(item) for item in items]
      else:
        values = list(items)
      target.value.extend(values)
  except Exception as e:  # pylint: disable=broad-except
    logging.warning(
        "Error encountered when serializing %s.\n"
        "Type is unsupported, or the types of the items don't "
        "match field type in CollectionDef.\n%s", key, str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
def _import_meta_graph_def(meta_graph_def):
  """Recreates a Graph saved in a `MetaGraphDef` proto.

  This function adds all the nodes from the meta graph def proto to the
  current graph, recreates all the collections, and returns a saver from
  saver_def.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer.

  Returns:
    A `Saver` constructed from `saver_def` in `meta_graph_def`, or `Saver()`
    when `meta_graph_def` has no `saver_def`.

  Raises:
    ValueError: If a collection with a registered from-proto function is not
      stored as a `bytes_list`.
  """
  # Gathers the list of nodes we are interested in.
  importer.import_graph_def(meta_graph_def.graph_def, name="")

  # Restores all the other collections.
  for key, col_def in meta_graph_def.collection_def.items():
    kind = col_def.WhichOneof("kind")
    if kind is None:
      # Pass `key` as a logging argument so formatting is deferred.
      logging.error("Cannot identify data type for collection %s. Skipping.",
                    key)
      continue
    from_proto = ops.get_from_proto_function(key)
    if from_proto:
      # Explicit check instead of `assert`. An assert is stripped under `-O`,
      # and the unset oneof member `bytes_list` would then silently read as
      # empty.
      if kind != "bytes_list":
        raise ValueError(
            "Collection %s has a from-proto function but is stored as %s, "
            "not bytes_list." % (key, kind))
      proto_type = ops.get_collection_proto_type(key)
      for value in col_def.bytes_list.value:
        proto = proto_type()
        proto.ParseFromString(value)
        ops.add_to_collection(key, from_proto(proto))
    else:
      field = getattr(col_def, kind)
      if kind == "node_list":
        for value in field.value:
          col_op = ops.get_default_graph().as_graph_element(value)
          ops.add_to_collection(key, col_op)
      elif kind == "int64_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python2 distinguishes between int and long, while Python3 has
        # only int.
        for value in field.value:
          ops.add_to_collection(key, int(value))
      else:
        for value in field.value:
          ops.add_to_collection(key, value)

  if meta_graph_def.HasField("saver_def"):
    return Saver(saver_def=meta_graph_def.saver_def)
  else:
    return Saver()
def _get_all_protos_from_collection(meta_graph_def, collection_key): """Obtain node names from a collection.""" if collection_key not in meta_graph_def.collection_def: return [] collection = meta_graph_def.collection_def[collection_key] if not collection.bytes_list.value: raise ValueError( 'Collection {} is present but type is not bytes_list.'.format( collection_key)) proto_type = _ops.get_collection_proto_type(collection_key) result = [] for value in collection.bytes_list.value: proto = proto_type() proto.ParseFromString(value) result.append(proto) return result
def add_collection_def(meta_graph_def, key):
  """Copies graph collection `key` into `meta_graph_def.collection_def`.

  Items come from `ops.get_collection(key)`. If a to-proto function is
  registered for `key`, each item is converted, type-checked against the
  registered proto type and stored serialized in `bytes_list`. Otherwise the
  destination field is picked by `_get_kind_name` on the first item. Keys that
  are not strings/bytes produce a warning and are skipped. Failures while
  filling the entry are logged, and the partially written entry is deleted.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
  """
  if not isinstance(key, (six.string_types, bytes)):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  collection_items = ops.get_collection(key)
  if not collection_items:
    return

  try:
    entry = meta_graph_def.collection_def[key]
    converter = ops.get_to_proto_function(key)
    registered_type = ops.get_collection_proto_type(key)
    if not converter:
      field_name = _get_kind_name(collection_items[0])
      repeated = getattr(entry, field_name).value
      if field_name == "node_list":
        repeated.extend([obj.name for obj in collection_items])
      elif field_name == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        repeated.extend([compat.as_bytes(obj) for obj in collection_items])
      else:
        repeated.extend(list(collection_items))
    else:
      for obj in collection_items:
        serialized = converter(obj)
        # The registered converter must yield the registered proto type.
        if not isinstance(serialized, registered_type):
          raise TypeError(
              "proto %s is not type %s" % (serialized, registered_type))
        entry.bytes_list.value.append(serialized.SerializeToString())
  except Exception as e:  # pylint: disable=broad-except
    logging.warning(
        "Error encountered when serializing %s.\n"
        "Type is unsupported, or the types of the items don't "
        "match field type in CollectionDef.\n%s", key, str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
def _restore_collections(dest_graph, src_meta_graph_def, collection_keys): """Restores collections that we need to keep.""" scope = "" for key in collection_keys: collection_def = src_meta_graph_def.collection_def[key] kind = collection_def.WhichOneof("kind") if kind is None: tf_logging.error( "Cannot identify data type for collection %s. Skipping.", key) continue from_proto = ops.get_from_proto_function(key) if from_proto and kind == "bytes_list": proto_type = ops.get_collection_proto_type(key) # It is assumed that there are no Variables Keys in collections for value in collection_def.bytes_list.value: proto = proto_type() proto.ParseFromString(value) try: new_value = from_proto(proto, import_scope=scope) except: continue dest_graph.add_to_collection(key, new_value) else: field = getattr(collection_def, kind) if kind == "node_list": for value in field.value: name = ops.prepend_name_scope(value, scope) # Since the graph has been optimized, the node may no longer # exists try: col_op = dest_graph.as_graph_element(name) except (TypeError, ValueError, KeyError) as e: continue dest_graph.add_to_collection(key, col_op) elif kind == "int64_list": # NOTE(opensource): This force conversion is to work around the # fact that Python2 distinguishes between int and long, while # Python3 has only int. for value in field.value: dest_graph.add_to_collection(key, int(value)) else: for value in field.value: dest_graph.add_to_collection( key, ops.prepend_name_scope(value, scope))
def _restore_collections(dest_graph, src_meta_graph_def, collection_keys): """Restores collections that we need to keep.""" scope = "" for key in collection_keys: collection_def = src_meta_graph_def.collection_def[key] kind = collection_def.WhichOneof("kind") if kind is None: tf_logging.error( "Cannot identify data type for collection %s. Skipping.", key) continue from_proto = ops.get_from_proto_function(key) if from_proto and kind == "bytes_list": proto_type = ops.get_collection_proto_type(key) # It is assumed that there are no Variables Keys in collections for value in collection_def.bytes_list.value: proto = proto_type() proto.ParseFromString(value) try: new_value = from_proto(proto, import_scope=scope) except: continue dest_graph.add_to_collection(key, new_value) else: field = getattr(collection_def, kind) if kind == "node_list": for value in field.value: name = ops.prepend_name_scope(value, scope) # Since the graph has been optimized, the node may no longer # exists try: col_op = dest_graph.as_graph_element(name) except (TypeError, ValueError, KeyError) as e: continue dest_graph.add_to_collection(key, col_op) elif kind == "int64_list": # NOTE(opensource): This force conversion is to work around the # fact that Python2 distinguishes between int and long, while # Python3 has only int. for value in field.value: dest_graph.add_to_collection(key, int(value)) else: for value in field.value: dest_graph.add_to_collection(key, ops.prepend_name_scope(value, scope))
def _tf_meta_graph_def_to_lgf_meta_graph_info(self, meta_graph_def):
  """Builds a `MetaGraphInfo` describing the non-node parts of a meta graph.

  A partial copy of `meta_graph_def` is serialized into
  `original_graph_info[ImportTFSavedModelBase.PARTIAL_META_GRAPH_DEF]`. The
  copy keeps the saver, collections and signatures, plus the graph with
  every node removed.

  Every non-empty string found in that partial proto is mapped through
  `self.get_node_name_and_output_index` and matched against the node names
  in `self._graph_def`. This includes strings inside protos serialized in
  `bytes_list` collections. Matches are added to `self._required_nodes`,
  which is then copied into `required_nodes` of the result.

  Args:
    meta_graph_def: The TensorFlow `MetaGraphDef` to summarize.

  Returns:
    An `lgf_pb2.MetaGraphInfo`.
  """
  # Store the serialized partial meta graph def.
  partial_meta_graph_def = tf.MetaGraphDef()
  partial_meta_graph_def.saver_def.CopyFrom(meta_graph_def.saver_def)
  for k, v in meta_graph_def.collection_def.items():
    partial_meta_graph_def.collection_def[k].CopyFrom(v)
  for k, v in meta_graph_def.signature_def.items():
    partial_meta_graph_def.signature_def[k].CopyFrom(v)
  partial_meta_graph_def.graph_def.CopyFrom(meta_graph_def.graph_def)
  # Keep the graph-level fields but drop the nodes themselves.
  del partial_meta_graph_def.graph_def.node[:]

  meta_graph_info = lgf_pb2.MetaGraphInfo()
  meta_graph_info.original_graph_info[
      ImportTFSavedModelBase.PARTIAL_META_GRAPH_DEF].v = (
          partial_meta_graph_def.SerializeToString())

  # Get all strings from the partial meta graph def.
  proto_strings = self.get_strings_from_proto(partial_meta_graph_def)

  # Collection entries stored as serialized protos are opaque bytes, so they
  # have to be deserialized manually to reach the strings inside them.
  for k, collection_def in partial_meta_graph_def.collection_def.items():
    if collection_def.HasField("bytes_list"):
      proto_type = tf_ops.get_collection_proto_type(k)
      if proto_type is None:
        # No proto type is registered for this key (presumably a user
        # collection of raw strings). There is no message to parse into, and
        # calling `None()` would crash.
        continue
      for serialized_proto in collection_def.bytes_list.value:
        proto = proto_type()
        proto.ParseFromString(serialized_proto)
        proto_strings.extend(self.get_strings_from_proto(proto))

  # Get all the node names from proto_strings.
  node_names = {n.name for n in self._graph_def.node}
  for string in proto_strings:
    if string != "":
      name = self.get_node_name_and_output_index(string)[0]
      if name in node_names:
        self._required_nodes.add(name)

  # Add required nodes to meta_graph_info.
  meta_graph_info.required_nodes.extend(self._required_nodes)
  return meta_graph_info
def _add_collection_def(meta_graph_def, key): """Adds a collection to MetaGraphDef protocol buffer. Args: meta_graph_def: MetaGraphDef protocol buffer. key: One of the GraphKeys or user-defined string. """ if not isinstance(key, (str, bytes, unicode)): logging.warning("Only collections with string type keys will be " "serialized. This key has %s" % type(key)) return collection_list = ops.get_collection(key) if not collection_list: return try: col_def = meta_graph_def.collection_def[key] to_proto = ops.get_to_proto_function(key) proto_type = ops.get_collection_proto_type(key) if to_proto: kind = "bytes_list" for x in collection_list: # Additional type check to make sure the returned proto is indeed # what we expect. proto = to_proto(x) assert isinstance(proto, proto_type) getattr(col_def, kind).value.append(proto.SerializeToString()) else: kind = _get_kind_name(collection_list[0]) if kind == "node_list": getattr(col_def, kind).value.extend([x.name for x in collection_list]) else: getattr(col_def, kind).value.extend([x for x in collection_list]) except Exception, e: # pylint: disable=broad-except logging.warning("Type is unsupported, or the types of the items don't " "match field type in CollectionDef.\n%s" % str(e)) if key in meta_graph_def.collection_def: del meta_graph_def.collection_def[key] return
def import_scoped_meta_graph_with_return_elements(
    meta_graph_or_file,
    clear_devices=False,
    graph=None,
    import_scope=None,
    input_map=None,
    unbound_inputs_col_name="unbound_inputs",
    restore_collections_predicate=(lambda key: True),
    return_elements=None):
  """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If the
  argument is a file containing a `MetaGraphDef` protocol buffer, it
  constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. Devices are cleared in place on the
      `graph_def` being imported.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements: A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If the graph_def contains unbound inputs, or if eager
      execution is enabled.
  """
  if context.executing_eagerly():
    raise ValueError(
        "Exporting/importing meta graphs is not supported when "
        "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # A collection with no kind set holds no values. Guard against
        # `getattr(col_def, None)`, which would raise TypeError.
        raw_values = getattr(col_def, kind).value if kind else []
        # Normalize to `str` once so the mismatch test and the error message
        # agree. The stored values may be bytes, while `input_map` keys are
        # strings.
        unbound_names = [compat.as_str(v) for v in raw_values]
        if unbound_names and (not input_map or
                              sorted(unbound_names) != sorted(input_map)):
          raise ValueError(
              "Graph contains unbound inputs: %s. Must "
              "provide these inputs through input_map." % ",".join(
                  name for name in unbound_names
                  if not input_map or name not in input_map))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    imported_return_elements = importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list,
        return_elements=return_elements)

    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
    # without a VariableDef.trainable field set.
    tf_version = meta_graph_def.meta_info_def.tensorflow_version
    if not tf_version:
      variables_have_trainable = True
    else:
      variables_have_trainable = (
          packaging_version.parse(tf_version) >= packaging_version.parse("1.9"))

    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
    # variables to trainable if the value is not set in their VariableDef.
    sorted_collections = []
    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
      sorted_collections.append(
          (ops.GraphKeys.TRAINABLE_VARIABLES,
           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
    for key, value in sorted(meta_graph_def.collection_def.items()):
      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
        sorted_collections.append((key, value))

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted_collections:
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error(
            "Cannot identify data type for collection %s. Skipping.", key)
        continue
      from_proto = ops.get_from_proto_function(key)

      # Temporary change to allow the TFMA evaluator to read metric variables
      # saved as a bytes list.
      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
      if key == ops.GraphKeys.METRIC_VARIABLES:
        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
        from_proto = ops.get_from_proto_function(
            ops.GraphKeys.GLOBAL_VARIABLES)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            # The same serialized variable can appear in several variable
            # collections; build it once and reuse the object.
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              if not variables_have_trainable:
                # If the VariableDef proto does not contain a "trainable"
                # property because it was exported before that property was
                # added, we default it to whether the variable is in the
                # TRAINABLE_VARIABLES collection. We've sorted
                # TRAINABLE_VARIABLES to be first, so trainable variables will
                # be created from that collection.
                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list, imported_return_elements
def add_collection_def(meta_graph_def,
                       key,
                       graph=None,
                       export_scope=None,
                       exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Errors raised while filling the `CollectionDef` are logged as warnings and
  the partially written entry is removed from `meta_graph_def`.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections. Defaults to
      `ops.get_default_graph()` when not given.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set).

  Raises:
    TypeError: If `graph` is given and is not an `ops.Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError(
        f"graph must be of type Graph. Received type: {type(graph)}.")

  if not isinstance(key, str) and not isinstance(key, bytes):
    logging.warning(
        "Only collections with string type keys will be "
        "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  if override_contents:
    collection_list = override_contents
  else:
    collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [
      x for x in collection_list
      if _should_include_node(x, export_scope, exclude_nodes)
  ]

  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect. Falsy results from `to_proto` are skipped.
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          # Explicit check instead of `assert`, which is stripped under `-O`
          # and would let a wrong-typed proto be serialized silently. The
          # TypeError is handled by the `except` below.
          if not isinstance(proto, proto_type):
            raise TypeError(f"proto {proto} is not type {proto_type}")
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend([x for x in collection_list])
  except Exception as e:  # pylint: disable=broad-except
    logging.warning(
        "Issue encountered when serializing %s.\n"
        "Type is unsupported, or the types of the items don't "
        "match field type in CollectionDef. Note this is a warning "
        "and probably safe to ignore.\n%s", key, str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If the
  argument is a file containing a `MetaGraphDef` protocol buffer, it
  constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. Devices are cleared in place on the
      `graph_def` being imported.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs, or if eager
      execution is enabled.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # A collection with no kind set holds no values. Guard against
        # `getattr(col_def, None)`, which would raise TypeError.
        raw_values = getattr(col_def, kind).value if kind else []
        # Normalize to `str` once so the mismatch test and the error message
        # agree. The stored values may be bytes, while `input_map` keys are
        # strings.
        unbound_names = [compat.as_str(v) for v in raw_values]
        if unbound_names and (not input_map or
                              sorted(unbound_names) != sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([name for name in unbound_names
                                     if not input_map or
                                     name not in input_map]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            # The same serialized variable can appear in several variable
            # collections; build it once and reuse the object.
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
      A falsy value skips the unbound-input check.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope, keyed by
    variable name with the import scope stripped.

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs whose names do not exactly match the keys of `input_map`.
  """
  if context.in_eager_mode():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # An empty collection has no "kind"; getattr(col_def, None) would raise
        # TypeError, and there is nothing to validate anyway.
        if kind is None:
          break
        field = getattr(col_def, kind)
        # Collection values are serialized, while `input_map` keys are strings:
        # normalize once so both the check and the error message compare the
        # same representation.
        unbound_inputs = [compat.as_str(v) for v in field.value]
        if unbound_inputs and (
            not input_map or sorted(unbound_inputs) != sorted(input_map)):
          missing = [v for v in unbound_inputs
                     if not input_map or v not in input_map]
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(missing))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list

    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    # NOTE(review): this edits `meta_graph_def.graph_def` in place, so a
    # `MetaGraphDef` passed in by the caller is modified as well.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    # Scope prefix used when restoring collection entries below; computed
    # without reserving the name (mark_as_used=False).
    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections, in sorted key order.
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        # Each entry is a serialized proto; rebuild the Python object from it.
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates all the collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
      A falsy value skips the unbound-input check. This collection is never
      restored into the new graph.

  Returns:
    A dictionary of all the `Variables` imported into the name scope, keyed by
    variable name with `import_scope` stripped.

  Raises:
    ValueError: If the graph_def contains unbound inputs whose names do not
      exactly match the keys of `input_map`.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # An empty collection has no "kind"; getattr(col_def, None) would raise
        # TypeError, and there is nothing to validate anyway.
        if kind is None:
          break
        field = getattr(col_def, kind)
        # Collection values are serialized, while `input_map` keys are strings:
        # normalize once so the check and the message agree.
        unbound_inputs = [compat.as_str(v) for v in field.value]
        if unbound_inputs and (
            not input_map or sorted(unbound_inputs) != sorted(input_map)):
          # Report only the inputs the caller has not mapped.
          missing = [v for v in unbound_inputs
                     if not input_map or v not in input_map]
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(missing))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list

    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    # NOTE(review): this edits `meta_graph_def.graph_def` in place, so a
    # `MetaGraphDef` passed in by the caller is modified as well.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections. Keys are sorted so the restoration
    # order is deterministic (protobuf map iteration order is not guaranteed).
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        # Each entry is a serialized proto; rebuild the Python object from it.
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
      else:
        if from_proto:
          # This used to be an `assert`, which aborted the import (or, under
          # `python -O`, silently restored nothing). Restore the raw values
          # generically instead and tell the user.
          logging.warning(
              "Collection '%s' has a registered from_proto function but was "
              "saved as '%s' instead of 'bytes_list'; restoring raw values.",
              key, kind)
        field = getattr(col_def, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, import_scope))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, import_scope))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.VARIABLES,
                                     scope=import_scope)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, import_scope)] = v

  return var_list