def to_proto(self, export_scope=None):
    """Serializes this `QueueRunner` into a `QueueRunnerDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `QueueRunnerDef` protocol buffer, or `None` if the name of this
      runner's queue does not start with `export_scope`.
    """
    # Guard clause: nothing is exported for a queue outside the scope.
    if (export_scope is not None
            and not self.queue.name.startswith(export_scope)):
        return None

    def _unscoped(name):
        # Every name stored in the proto has the export scope removed.
        return ops.strip_name_scope(name, export_scope)

    runner_def = queue_runner_pb2.QueueRunnerDef()
    runner_def.queue_name = _unscoped(self.queue.name)
    for op in self.enqueue_ops:
        runner_def.enqueue_op_name.append(_unscoped(op.name))
    runner_def.close_op_name = _unscoped(self.close_op.name)
    runner_def.cancel_op_name = _unscoped(self.cancel_op.name)

    # Exception classes are stored as their error codes.
    closed_codes = [
        errors.error_code_from_exception_type(cls)
        for cls in self._queue_closed_exception_types
    ]
    runner_def.queue_closed_exception_types.extend(closed_codes)
    return runner_def
def to_proto(self, export_scope=None):
    """Serializes this `ResourceVariable` into a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the name of the variable's
      handle does not start with `export_scope`.
    """
    # Guard clause: variables outside the requested scope are not exported.
    if (export_scope is not None
            and not self.handle.name.startswith(export_scope)):
        return None

    def _unscoped(name):
        return ops.strip_name_scope(name, export_scope)

    proto = variable_pb2.VariableDef()
    proto.variable_name = _unscoped(self.handle.name)
    proto.initializer_name = _unscoped(self.initializer.name)
    # The snapshot name is only recorded when a cached value exists.
    if self._cached_value is not None:
        proto.snapshot_name = _unscoped(self._cached_value.name)
    proto.is_resource = True
    if self._save_slice_info:
        slice_proto = self._save_slice_info.to_proto(export_scope=export_scope)
        proto.save_slice_info_def.MergeFrom(slice_proto)
    return proto
def add_collection_def(meta_graph_def, key, graph=None, export_scope=None):
    """Adds a collection to MetaGraphDef protocol buffer.

    If a `to_proto` function is registered for `key`, every item is converted
    with it and stored serialized in the `bytes_list` field. Otherwise the
    field is chosen by `_get_kind_name` from the first item of the collection.
    Any exception raised while filling the entry is logged as a warning and
    the entry for `key` is removed from `meta_graph_def` again.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
      graph: The `Graph` from which to get collections.
      export_scope: Optional `string`. Name scope to remove.

    Raises:
      TypeError: If `graph` is given but is not an `ops.Graph`.
    """
    if graph and not isinstance(graph, ops.Graph):
        # Exception constructors, unlike logging calls, do not apply %-style
        # arguments, so the message has to be interpolated here.
        raise TypeError("graph must be of type Graph, not %s" % type(graph))

    if not isinstance(key, six.string_types) and not isinstance(key, bytes):
        logging.warning("Only collections with string type keys will be "
                        "serialized. This key has %s", type(key))
        return

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    collection_list = graph.get_collection(key)
    if not collection_list:
        return

    try:
        col_def = meta_graph_def.collection_def[key]
        to_proto = ops.get_to_proto_function(key)
        proto_type = ops.get_collection_proto_type(key)
        if to_proto:
            kind = "bytes_list"
            for x in collection_list:
                proto = to_proto(x, export_scope=export_scope)
                # A falsy proto (e.g. `None`) is simply left out of the
                # collection.
                if proto:
                    # Additional type check to make sure the returned proto is
                    # indeed what we expect.
                    assert isinstance(proto, proto_type)
                    getattr(col_def, kind).value.append(
                        proto.SerializeToString())
        else:
            kind = _get_kind_name(collection_list[0])
            if kind == "node_list":
                for x in collection_list:
                    if not export_scope or x.name.startswith(export_scope):
                        getattr(col_def, kind).value.append(
                            ops.strip_name_scope(x.name, export_scope))
            elif kind == "bytes_list":
                # NOTE(opensource): This force conversion is to work around the
                # fact that Python3 distinguishes between bytes and strings.
                getattr(col_def, kind).value.extend(
                    [compat.as_bytes(x) for x in collection_list])
            else:
                getattr(col_def, kind).value.extend(list(collection_list))
    except Exception as e:  # pylint: disable=broad-except
        # Deliberate best effort: an unserializable collection is dropped from
        # the MetaGraphDef instead of aborting the export.
        logging.warning("Error encountered when serializing %s.\n"
                        "Type is unsupported, or the types of the items don't "
                        "match field type in CollectionDef.\n%s", key, str(e))
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
        return
def _node_def(from_node_def, export_scope, unbound_inputs,
              clear_devices=False):
    """Builds a copy of a `NodeDef` with `export_scope` stripped from its names.

    Args:
      from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
      export_scope: A `string` representing the name scope to remove.
      unbound_inputs: A list that receives the (prefixed) name of every input
        that lies outside `export_scope`; it is appended to in place.
      clear_devices: Boolean which controls whether to clear device
        information from the returned node. Default false.

    Returns:
      A `node_def_pb2.NodeDef` protocol buffer.
    """
    result = copy.deepcopy(from_node_def)

    for idx, input_name in enumerate(result.input):
        # A leading "^" is ignored when testing for scope membership.
        if (not export_scope
                or input_name.lstrip("^").startswith(export_scope)):
            result.input[idx] = ops.strip_name_scope(input_name, export_scope)
            continue
        # Adds "$unbound_inputs_" prefix to the unbound name so they are
        # easily identifiable.
        result.input[idx] = re.sub(
            r"([\^]|^)(.*)",
            r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
            compat.as_str(input_name))
        unbound_inputs.append(result.input[idx])

    result.name = compat.as_bytes(
        ops.strip_name_scope(from_node_def.name, export_scope))

    for attr_name, attr_value in six.iteritems(from_node_def.attr):
        if attr_name == "_class":
            # Keep only the entries whose part after "@" is inside the scope.
            kept = []
            for entry in attr_value.list.s:
                if (not export_scope or compat.as_str(entry).split("@")[1]
                        .startswith(export_scope)):
                    kept.append(compat.as_bytes(
                        ops.strip_name_scope(entry, export_scope)))
            result.attr[attr_name].CopyFrom(attr_value_pb2.AttrValue(
                list=attr_value_pb2.AttrValue.ListValue(s=kept)))
        elif result.op in ("Enter", "RefEnter") and attr_name == "frame_name":
            # A frame name outside the scope keeps the value that the deep
            # copy above already carried over.
            if (not export_scope
                    or compat.as_str(attr_value.s).startswith(export_scope)):
                stripped = compat.as_bytes(
                    ops.strip_name_scope(attr_value.s, export_scope))
                result.attr[attr_name].CopyFrom(
                    attr_value_pb2.AttrValue(s=stripped))
        else:
            result.attr[attr_name].CopyFrom(attr_value)

    if clear_devices:
        result.device = ""
    return result
def to_proto(self, export_scope=None):
    """Serializes this `ResourceVariable` into a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Raises:
      RuntimeError: If run in EAGER mode.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the name of the variable's
      handle does not start with `export_scope`.
    """
    if context.executing_eagerly():
        raise RuntimeError("to_proto not supported in EAGER mode.")

    # Guard clause: variables outside the requested scope are not exported.
    if (export_scope is not None
            and not self.handle.name.startswith(export_scope)):
        return None

    def _unscoped(name):
        return ops.strip_name_scope(name, export_scope)

    proto = variable_pb2.VariableDef()
    proto.variable_name = _unscoped(self.handle.name)
    if self._initial_value is not None:
        # This is inside an if-statement for backwards compatibility, since
        # self._initial_value might be None for variables constructed from old
        # protos.
        proto.initial_value_name = _unscoped(self._initial_value.name)
    proto.initializer_name = _unscoped(self.initializer.name)

    if self._cached_value is None:
        # Store the graph_element here
        proto.snapshot_name = _unscoped(self._graph_element.name)
    else:
        proto.snapshot_name = _unscoped(self._cached_value.name)

    proto.is_resource = True
    proto.trainable = self.trainable
    if self._save_slice_info:
        slice_proto = self._save_slice_info.to_proto(export_scope=export_scope)
        proto.save_slice_info_def.MergeFrom(slice_proto)
    return proto
def _node_def(from_node_def, export_scope, unbound_inputs):
    """Builds a copy of a `NodeDef` with `export_scope` stripped from its names.

    Args:
      from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
      export_scope: A `string` representing the name scope to remove.
      unbound_inputs: A list that receives the (prefixed) name of every input
        that lies outside `export_scope`; it is appended to in place.

    Returns:
      A `node_def_pb2.NodeDef` protocol buffer.
    """
    result = copy.deepcopy(from_node_def)

    for idx, input_name in enumerate(result.input):
        # A leading "^" is ignored when testing for scope membership.
        if (not export_scope
                or input_name.lstrip("^").startswith(export_scope)):
            result.input[idx] = ops.strip_name_scope(input_name, export_scope)
            continue
        # Adds "$unbound_inputs_" prefix to the unbound name so they are
        # easily identifiable.
        result.input[idx] = re.sub(r"([\^]|^)(.*)", r"\1$unbound_inputs_\2",
                                   compat.as_str(input_name))
        unbound_inputs.append(result.input[idx])

    result.name = compat.as_bytes(
        ops.strip_name_scope(from_node_def.name, export_scope))

    for attr_name, attr_value in six.iteritems(from_node_def.attr):
        if attr_name != "_class":
            result.attr[attr_name].CopyFrom(attr_value)
            continue
        # Keep only the entries whose part after "@" is inside the scope.
        kept = []
        for entry in attr_value.list.s:
            if (not export_scope or compat.as_str(entry).split("@")[1]
                    .startswith(export_scope)):
                kept.append(compat.as_bytes(
                    ops.strip_name_scope(entry, export_scope)))
        result.attr[attr_name].CopyFrom(attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(s=kept)))

    return result
def to_proto(self, export_scope=None):
    """Builds a `SaveSliceInfoDef` proto from this slice description.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `SaveSliceInfoDef` protocol buffer, or None if `full_name` does not
      start with `export_scope`.
    """
    # Guard clause: slices of variables outside the scope are not exported.
    if (export_scope is not None
            and not self.full_name.startswith(export_scope)):
        return None

    slice_def = variable_pb2.SaveSliceInfoDef()
    slice_def.full_name = ops.strip_name_scope(self.full_name, export_scope)

    # Copy the three integer sequences element by element into the matching
    # repeated fields of the proto.
    copies = (
        (slice_def.full_shape, self.full_shape),
        (slice_def.var_offset, self.var_offset),
        (slice_def.var_shape, self.var_shape),
    )
    for repeated_field, values in copies:
        for item in values:
            repeated_field.append(item)
    return slice_def
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
    """Returns `MetaGraphDef` proto. Optionally writes it to filename.

    The graph, saver and collection objects are exported into a `MetaGraphDef`
    protocol buffer so that they can be imported at a later time or location
    to restart training, run inference, or be used as a subgraph.

    Args:
      filename: Optional filename including the path for writing the generated
        `MetaGraphDef` protocol buffer.
      graph_def: `GraphDef` protocol buffer.
      graph: The `Graph` to export. If `None`, use the default graph.
      export_scope: Optional `string`. Name scope under which to extract the
        subgraph. The scope name will be stripped from the node definitions
        for easy import later into new name scopes. If `None`, the whole graph
        is exported.
      as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
      unbound_inputs_col_name: Optional `string`. If provided, a string
        collection with the given name will be added to the returned
        `MetaGraphDef`, containing the names of tensors that must be remapped
        when importing the `MetaGraphDef`.
      clear_devices: Boolean which controls whether to clear device
        information before exporting the graph.
      saver_def: `SaverDef` protocol buffer.
      clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with the provided SaverDef.
      strip_default_attrs: Set to true if default valued attributes must be
        removed while exporting the GraphDef.
      save_debug_info: If `True` and `filename` is given, also writes the
        graph debug info to a separate file next to `filename`, named like
        `filename` with its extension replaced by `.debug`.
      **kwargs: Optional keyed arguments, including meta_info_def and
        collection_list.

    Returns:
      A `MetaGraphDef` proto and dictionary of `Variables` in the exported
      name scope.

    Raises:
      ValueError: When the `GraphDef` is larger than 2GB.
      ValueError: When executing in Eager mode and either `graph_def` or
        `graph` is undefined.
    """
    have_both = graph_def is not None and graph is not None
    if context.executing_eagerly() and not have_both:
        raise ValueError("Exporting/importing meta graphs is not supported when "
                         "Eager Execution is enabled.")
    graph = graph or ops.get_default_graph()

    exclude_nodes = None
    unbound_inputs = []
    if export_scope or clear_extraneous_savers or clear_devices:
        if graph_def:
            # Filter and rewrite the nodes of the GraphDef that was passed in.
            filtered_def = graph_pb2.GraphDef()
            filtered_def.versions.CopyFrom(graph_def.versions)
            filtered_def.library.CopyFrom(graph_def.library)

            if clear_extraneous_savers:
                exclude_nodes = _find_extraneous_saver_nodes(graph_def,
                                                             saver_def)

            for source_node in graph_def.node:
                if not _should_include_node(source_node.name, export_scope,
                                            exclude_nodes):
                    continue
                rewritten = _node_def(source_node, export_scope,
                                      unbound_inputs,
                                      clear_devices=clear_devices)
                filtered_def.node.extend([rewritten])
            graph_def = filtered_def
        else:
            # Only do this complicated work if we want to remove a name scope.
            graph_def = graph_pb2.GraphDef()
            graph_def.versions.CopyFrom(graph.graph_def_versions)
            total_bytes = 0

            if clear_extraneous_savers:
                exclude_nodes = _find_extraneous_saver_nodes(
                    graph.as_graph_def(), saver_def)

            # pylint: disable=protected-access
            for node_id in sorted(graph._nodes_by_id):
                op = graph._nodes_by_id[node_id]
                # pylint: enable=protected-access
                if not _should_include_node(op.name, export_scope,
                                            exclude_nodes):
                    continue
                rewritten = _node_def(op.node_def, export_scope,
                                      unbound_inputs,
                                      clear_devices=clear_devices)
                graph_def.node.extend([rewritten])
                if op.outputs:
                    exported = graph_def.node[-1]
                    assert "_output_shapes" not in exported.attr
                    exported.attr["_output_shapes"].list.shape.extend([
                        tensor.get_shape().as_proto() for tensor in op.outputs
                    ])
                # The running total is checked after every node.
                total_bytes += op.node_def.ByteSize()
                if total_bytes >= (1 << 31) or total_bytes < 0:
                    raise ValueError("GraphDef cannot be larger than 2GB.")

            graph._copy_functions_to_graph_def(graph_def, total_bytes)  # pylint: disable=protected-access

        # It's possible that not all the inputs are in the export_scope.
        # If we would like such information included in the exported
        # meta_graph, add them to a special unbound_inputs collection.
        if unbound_inputs_col_name:
            # Clears the unbound_inputs collections.
            graph.clear_collection(unbound_inputs_col_name)
            for input_name in unbound_inputs:
                graph.add_to_collection(unbound_inputs_col_name, input_name)

    scoped_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                            scope=export_scope)
    var_list = {
        ops.strip_name_scope(var.name, export_scope): var
        for var in scoped_variables
        if _should_include_node(var, export_scope, exclude_nodes)
    }

    scoped_meta_graph_def = create_meta_graph_def(
        graph_def=graph_def,
        graph=graph,
        export_scope=export_scope,
        exclude_nodes=exclude_nodes,
        clear_extraneous_savers=clear_extraneous_savers,
        saver_def=saver_def,
        strip_default_attrs=strip_default_attrs,
        **kwargs)

    if filename:
        graph_io.write_graph(
            scoped_meta_graph_def,
            os.path.dirname(filename),
            os.path.basename(filename),
            as_text=as_text)
        if save_debug_info:
            stem, _ = os.path.splitext(filename)
            debug_filename = "{name}{ext}".format(name=stem, ext=".debug")

            # Gets the operation from the graph by the name. Excludes variable
            # nodes, so only the nodes in the frozen models are included.
            # TODO(liufengdb): fix this for functions.
            ops_to_export = [
                ("", graph.get_operation_by_name(
                    ops.prepend_name_scope(node.name, export_scope)))
                for node in scoped_meta_graph_def.graph_def.node
            ]
            graph_debug_info = error_interpolation.create_graph_debug_info_def(
                ops_to_export)

            graph_io.write_graph(
                graph_debug_info,
                os.path.dirname(debug_filename),
                os.path.basename(debug_filename),
                as_text=as_text)

    return scoped_meta_graph_def, var_list
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
    """Recreates a `Graph` saved in a `MetaGraphDef` proto.

    This function takes a `MetaGraphDef` protocol buffer as input. If the
    argument is a file containing a `MetaGraphDef` protocol buffer, it
    constructs a protocol buffer from the file content. The function then adds
    all the nodes from the `graph_def` field to the current graph, recreates
    the desired collections, and returns a dictionary of all the Variables
    imported into the name scope.

    In combination with `export_scoped_meta_graph()`, this function can be
    used to

    * Serialize a graph along with other Python objects such as `QueueRunner`,
      `Variable` into a `MetaGraphDef`.

    * Restart training from a saved graph and checkpoints.

    * Run inference from a saved graph and checkpoints.

    Args:
      meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
        the path) containing a `MetaGraphDef`.
      clear_devices: Boolean which controls whether to clear device information
        from graph_def. Default false.
      graph: The `Graph` to import into. If `None`, use the default graph.
      import_scope: Optional `string`. Name scope into which to import the
        subgraph. If `None`, the graph is imported to the root name scope.
      input_map: A dictionary mapping input names (as strings) in `graph_def`
        to `Tensor` objects. The values of the named input tensors in the
        imported graph will be re-mapped to the respective `Tensor` values.
      unbound_inputs_col_name: Collection name for looking up unbound inputs.
      restore_collections_predicate: a predicate on collection names. A
        collection named c (i.e whose key is c) will be restored iff
        1) `restore_collections_predicate(c)` is True, and
        2) `c != unbound_inputs_col_name`.

    Returns:
      A dictionary of all the `Variables` imported into the name scope.

    Raises:
      ValueError: If the graph_def contains unbound inputs, or if eager
        execution is enabled.
    """
    if context.executing_eagerly():
        raise ValueError("Exporting/importing meta graphs is not supported when "
                         "eager execution is enabled.")
    if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
        meta_graph_def = meta_graph_or_file
    else:
        meta_graph_def = read_meta_graph_file(meta_graph_or_file)

    if unbound_inputs_col_name:
        for key, col_def in meta_graph_def.collection_def.items():
            if key == unbound_inputs_col_name:
                kind = col_def.WhichOneof("kind")
                field = getattr(col_def, kind)
                unbound_names = [compat.as_str(v) for v in field.value]
                if unbound_names and (
                        not input_map
                        or sorted(unbound_names) != sorted(input_map)):
                    # Compare the decoded name as well as the raw stored
                    # value: a raw `bytes` value never equals a `str` key on
                    # Python 3, so checking only the raw value would report
                    # inputs that were in fact supplied.
                    missing = [
                        name for raw, name in zip(field.value, unbound_names)
                        if not input_map
                        or (raw not in input_map and name not in input_map)
                    ]
                    raise ValueError(
                        "Graph contains unbound inputs: %s. Must "
                        "provide these inputs through input_map." %
                        ",".join(missing))
                break

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Gathers the list of nodes we are interested in.
    with graph.as_default():
        producer_op_list = None
        if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
            producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
        input_graph_def = meta_graph_def.graph_def
        # Remove all the explicit device specifications for this node. This
        # helps to make the graph more portable.
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        scope_to_prepend_to_names = graph.unique_name(
            import_scope or "", mark_as_used=False)

        importer.import_graph_def(
            input_graph_def,
            name=(import_scope or scope_to_prepend_to_names),
            input_map=input_map,
            producer_op_list=producer_op_list)

        # Restores all the other collections.
        # Variable objects are cached by serialized proto so that a proto
        # stored under several keys yields one shared object.
        variable_objects = {}
        for key, col_def in sorted(meta_graph_def.collection_def.items()):
            # Don't add unbound_inputs to the new graph.
            if key == unbound_inputs_col_name:
                continue
            if not restore_collections_predicate(key):
                continue

            kind = col_def.WhichOneof("kind")
            if kind is None:
                logging.error(
                    "Cannot identify data type for collection %s. Skipping.",
                    key)
                continue
            from_proto = ops.get_from_proto_function(key)
            if from_proto and kind == "bytes_list":
                proto_type = ops.get_collection_proto_type(key)
                if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
                    for value in col_def.bytes_list.value:
                        variable = variable_objects.get(value, None)
                        if variable is None:
                            proto = proto_type()
                            proto.ParseFromString(value)
                            variable = from_proto(
                                proto, import_scope=scope_to_prepend_to_names)
                            variable_objects[value] = variable
                        graph.add_to_collection(key, variable)
                else:
                    for value in col_def.bytes_list.value:
                        proto = proto_type()
                        proto.ParseFromString(value)
                        graph.add_to_collection(
                            key, from_proto(
                                proto, import_scope=scope_to_prepend_to_names))
            else:
                field = getattr(col_def, kind)
                if key in _COMPAT_COLLECTION_LIST:
                    logging.warning(
                        "The saved meta_graph is possibly from an older release:\n"
                        "'%s' collection should be of type 'byte_list', but instead "
                        "is of type '%s'.", key, kind)
                if kind == "node_list":
                    for value in field.value:
                        col_op = graph.as_graph_element(
                            ops.prepend_name_scope(
                                value, scope_to_prepend_to_names))
                        graph.add_to_collection(key, col_op)
                elif kind == "int64_list":
                    # NOTE(opensource): This force conversion is to work around
                    # the fact that Python2 distinguishes between int and long,
                    # while Python3 has only int.
                    for value in field.value:
                        graph.add_to_collection(key, int(value))
                else:
                    for value in field.value:
                        graph.add_to_collection(
                            key, ops.prepend_name_scope(
                                value, scope_to_prepend_to_names))

        var_list = {}
        variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                         scope=scope_to_prepend_to_names)
        for v in variables:
            var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

    return var_list
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             **kwargs):
    """Returns `MetaGraphDef` proto. Optionally writes it to filename.

    The graph, saver and collection objects are exported into a `MetaGraphDef`
    protocol buffer so that they can be imported at a later time or location
    to restart training, run inference, or be used as a subgraph.

    Args:
      filename: Optional filename including the path for writing the generated
        `MetaGraphDef` protocol buffer.
      graph_def: `GraphDef` protocol buffer.
      graph: The `Graph` to export. If `None`, use the default graph.
      export_scope: Optional `string`. Name scope under which to extract the
        subgraph. The scope name will be stripped from the node definitions
        for easy import later into new name scopes. If `None`, the whole graph
        is exported.
      as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
      unbound_inputs_col_name: Optional `string`. If provided, a string
        collection with the given name will be added to the returned
        `MetaGraphDef`, containing the names of tensors that must be remapped
        when importing the `MetaGraphDef`.
      clear_devices: Boolean which controls whether to clear device
        information before exporting the graph.
      saver_def: `SaverDef` protocol buffer.
      clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with the provided SaverDef.
      strip_default_attrs: Set to true if default valued attributes must be
        removed while exporting the GraphDef.
      **kwargs: Optional keyed arguments, including meta_info_def and
        collection_list.

    Returns:
      A `MetaGraphDef` proto and dictionary of `Variables` in the exported
      name scope.

    Raises:
      ValueError: When the `GraphDef` is larger than 2GB, or when eager
        execution is enabled.
    """
    if context.executing_eagerly():
        raise ValueError("Exporting/importing meta graphs is not supported when "
                         "Eager Execution is enabled.")
    graph = graph or ops.get_default_graph()

    exclude_nodes = None
    unbound_inputs = []
    if export_scope or clear_extraneous_savers or clear_devices:
        if graph_def:
            # Filter and rewrite the nodes of the GraphDef that was passed in.
            filtered_def = graph_pb2.GraphDef()
            filtered_def.versions.CopyFrom(graph_def.versions)
            filtered_def.library.CopyFrom(graph_def.library)

            if clear_extraneous_savers:
                exclude_nodes = _find_extraneous_saver_nodes(graph_def,
                                                             saver_def)

            for source_node in graph_def.node:
                if not _should_include_node(source_node.name, export_scope,
                                            exclude_nodes):
                    continue
                rewritten = _node_def(source_node, export_scope,
                                      unbound_inputs,
                                      clear_devices=clear_devices)
                filtered_def.node.extend([rewritten])
            graph_def = filtered_def
        else:
            # Only do this complicated work if we want to remove a name scope.
            graph_def = graph_pb2.GraphDef()
            graph_def.versions.CopyFrom(graph.graph_def_versions)
            total_bytes = 0

            if clear_extraneous_savers:
                exclude_nodes = _find_extraneous_saver_nodes(
                    graph.as_graph_def(), saver_def)

            # pylint: disable=protected-access
            for node_id in sorted(graph._nodes_by_id):
                op = graph._nodes_by_id[node_id]
                # pylint: enable=protected-access
                if not _should_include_node(op.name, export_scope,
                                            exclude_nodes):
                    continue
                rewritten = _node_def(op.node_def, export_scope,
                                      unbound_inputs,
                                      clear_devices=clear_devices)
                graph_def.node.extend([rewritten])
                if op.outputs:
                    exported = graph_def.node[-1]
                    assert "_output_shapes" not in exported.attr
                    exported.attr["_output_shapes"].list.shape.extend([
                        tensor.get_shape().as_proto() for tensor in op.outputs
                    ])
                # The running total is checked after every node.
                total_bytes += op.node_def.ByteSize()
                if total_bytes >= (1 << 31) or total_bytes < 0:
                    raise ValueError("GraphDef cannot be larger than 2GB.")

            graph._copy_functions_to_graph_def(graph_def, total_bytes)  # pylint: disable=protected-access

        # It's possible that not all the inputs are in the export_scope.
        # If we would like such information included in the exported
        # meta_graph, add them to a special unbound_inputs collection.
        if unbound_inputs_col_name:
            # Clears the unbound_inputs collections.
            graph.clear_collection(unbound_inputs_col_name)
            for input_name in unbound_inputs:
                graph.add_to_collection(unbound_inputs_col_name, input_name)

    scoped_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                            scope=export_scope)
    var_list = {
        ops.strip_name_scope(var.name, export_scope): var
        for var in scoped_variables
        if _should_include_node(var, export_scope, exclude_nodes)
    }

    scoped_meta_graph_def = create_meta_graph_def(
        graph_def=graph_def,
        graph=graph,
        export_scope=export_scope,
        exclude_nodes=exclude_nodes,
        clear_extraneous_savers=clear_extraneous_savers,
        saver_def=saver_def,
        strip_default_attrs=strip_default_attrs,
        **kwargs)

    if filename:
        graph_io.write_graph(
            scoped_meta_graph_def,
            os.path.dirname(filename),
            os.path.basename(filename),
            as_text=as_text)

    return scoped_meta_graph_def, var_list
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             **kwargs):
    """Returns `MetaGraphDef` proto. Optionally writes it to filename.

    The graph, saver and collection objects are exported into a `MetaGraphDef`
    protocol buffer so that they can be imported at a later time or location
    to restart training, run inference, or be used as a subgraph.

    Args:
      filename: Optional filename including the path for writing the generated
        `MetaGraphDef` protocol buffer.
      graph_def: `GraphDef` protocol buffer.
      graph: The `Graph` to export. If `None`, use the default graph.
      export_scope: Optional `string`. Name scope under which to extract the
        subgraph. The scope name will be stripped from the node definitions
        for easy import later into new name scopes. If `None`, the whole graph
        is exported. graph_def and export_scope cannot both be specified.
      as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
      unbound_inputs_col_name: Optional `string`. If provided, a string
        collection with the given name will be added to the returned
        `MetaGraphDef`, containing the names of tensors that must be remapped
        when importing the `MetaGraphDef`.
      **kwargs: Optional keyed arguments, including meta_info_def,
        saver_def, collection_list.

    Returns:
      A `MetaGraphDef` proto and dictionary of `Variables` in the exported
      name scope.

    Raises:
      ValueError: When the `GraphDef` is larger than 2GB, or when both
        `graph_def` and `export_scope` are specified.
    """
    graph = graph or ops.get_default_graph()
    if graph_def and export_scope:
        raise ValueError("graph_def and export_scope cannot both "
                         "be specified.")

    if graph_def is None and export_scope:
        unbound_inputs = []
        # Only do this complicated work if we want to remove a name scope.
        graph_def = graph_pb2.GraphDef()
        graph_def.versions.CopyFrom(graph._graph_def_versions)  # pylint: disable=protected-access
        total_bytes = 0
        for node_name in sorted(graph._nodes_by_name):  # pylint: disable=protected-access
            if not _should_include_node(node_name, export_scope):
                continue
            op = graph._nodes_by_name[node_name]  # pylint: disable=protected-access
            graph_def.node.extend(
                [_node_def(op.node_def, export_scope, unbound_inputs)])
            if op.outputs:
                exported = graph_def.node[-1]
                assert "_output_shapes" not in exported.attr
                exported.attr["_output_shapes"].list.shape.extend([
                    tensor.get_shape().as_proto() for tensor in op.outputs
                ])
            # The running total is checked after every node.
            total_bytes += op.node_def.ByteSize()
            if total_bytes >= (1 << 31) or total_bytes < 0:
                raise ValueError("GraphDef cannot be larger than 2GB.")

        # It's possible that not all the inputs are in the export_scope.
        # If we would like such information included in the exported
        # meta_graph, add them to a special unbound_inputs collection.
        if unbound_inputs_col_name:
            # Clears the unbound_inputs collections.
            graph.clear_collection(unbound_inputs_col_name)
            for input_name in unbound_inputs:
                graph.add_to_collection(unbound_inputs_col_name, input_name)

    scoped_variables = graph.get_collection(ops.GraphKeys.VARIABLES,
                                            scope=export_scope)
    var_list = {
        ops.strip_name_scope(var.name, export_scope): var
        for var in scoped_variables
        if _should_include_node(var, export_scope)
    }

    scoped_meta_graph_def = create_meta_graph_def(
        graph_def=graph_def,
        graph=graph,
        export_scope=export_scope,
        **kwargs)

    if filename:
        training_util.write_graph(
            scoped_meta_graph_def,
            os.path.dirname(filename),
            os.path.basename(filename),
            as_text=as_text)

    return scoped_meta_graph_def, var_list
# NOTE(review): in SOURCE every `logging.log_first_n(` call below was cut off
# after `logging.WARN,` (message, count and closing parenthesis missing), which
# made the fragment unparseable. The calls are completed here; the message
# wording is reconstructed, so confirm it against the project's intended text.


def __iadd__(self, other):
    """Handles `self += other`: warns once, then returns `self + other`."""
    logging.log_first_n(
        logging.WARN,
        "Variable += will be deprecated. Use variable.assign_add"
        " if you want assignment to the variable value or 'x = x + y'"
        " if you want a new python Tensor object.", 1)
    return self + other


def __isub__(self, other):
    """Handles `self -= other`: warns once, then returns `self - other`."""
    logging.log_first_n(
        logging.WARN,
        "Variable -= will be deprecated. Use variable.assign_sub"
        " if you want assignment to the variable value or 'x = x - y'"
        " if you want a new python Tensor object.", 1)
    return self - other


def __imul__(self, other):
    """Handles `self *= other`: warns once, then returns `self * other`."""
    logging.log_first_n(
        logging.WARN,
        "Variable *= will be deprecated. Use `var.assign(var * other)`"
        " if you want assignment to the variable value or `x = x * y`"
        " if you want a new python Tensor object.", 1)
    return self * other


def __idiv__(self, other):
    """Handles `self /= other`: warns once, then returns `self / other`."""
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other


def __itruediv__(self, other):
    """Handles in-place true division: warns once, returns `self / other`."""
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other


def __irealdiv__(self, other):
    """Handles in-place real division: warns once, returns `self / other`."""
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other


def __ipow__(self, other):
    """Handles `self **= other`: warns once, then returns `self ** other`."""
    logging.log_first_n(
        logging.WARN,
        "Variable **= will be deprecated. Use `var.assign(var ** other)`"
        " if you want assignment to the variable value or `x = x ** y`"
        " if you want a new python Tensor object.", 1)
    return self ** other


class SaveSliceInfo(object):
    """Describes one slice of a larger variable.

    Attributes (all set by `__init__`):
      full_name: Name of the full variable.
      full_shape: Sequence of integers, one per dimension of the full
        variable.
      var_offset: Sequence of integers, the offset of this slice per dimension.
      var_shape: Sequence of integers, the size of this slice per dimension.
    """

    def __init__(self,
                 full_name=None,
                 full_shape=None,
                 var_offset=None,
                 var_shape=None,
                 save_slice_info_def=None,
                 import_scope=None):
        """Creates a `SaveSliceInfo` from explicit values or from a proto.

        Args:
          full_name: Name of the full variable of which this is a slice.
          full_shape: Shape of the full variable, as a list of int.
          var_offset: Offset of this slice into the full variable.
          var_shape: Shape of this slice.
          save_slice_info_def: Optional `SaveSliceInfoDef` protocol buffer.
            When given, the four values above are ignored and read from the
            proto instead.
          import_scope: Optional `string`. Name scope to prepend to the
            `full_name` read from `save_slice_info_def`. Only used in the
            proto case.
        """
        if save_slice_info_def:
            assert isinstance(save_slice_info_def,
                              variable_pb2.SaveSliceInfoDef)
            self.full_name = ops.prepend_name_scope(
                save_slice_info_def.full_name, import_scope=import_scope)
            # Copy the repeated proto fields into plain Python lists.
            self.full_shape = list(save_slice_info_def.full_shape)
            self.var_offset = list(save_slice_info_def.var_offset)
            self.var_shape = list(save_slice_info_def.var_shape)
        else:
            self.full_name = full_name
            self.full_shape = full_shape
            self.var_offset = var_offset
            self.var_shape = var_shape

    @property
    def spec(self):
        """Returns the slice as a string.

        The result is the full-shape dimensions separated by spaces, then a
        space, then one `offset,size` pair per dimension joined with `:`.
        For example a full shape `[10, 20]` with offset `[0, 5]` and slice
        shape `[10, 15]` gives `"10 20 0,10:5,15"`.
        """
        full_shape_str = " ".join(["%d" % d for d in self.full_shape]) + " "
        sl_spec = ":".join([
            "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape)
        ])
        return full_shape_str + sl_spec

    def to_proto(self, export_scope=None):
        """Returns a `SaveSliceInfoDef` proto for this slice.

        Args:
          export_scope: Optional `string`. Name scope to remove.

        Returns:
          A `SaveSliceInfoDef` protocol buffer, or None if `full_name` does
          not start with `export_scope`.
        """
        if (export_scope is None or
                self.full_name.startswith(export_scope)):
            save_slice_info_def = variable_pb2.SaveSliceInfoDef()
            save_slice_info_def.full_name = ops.strip_name_scope(
                self.full_name, export_scope)
            for i in self.full_shape:
                save_slice_info_def.full_shape.append(i)
            for i in self.var_offset:
                save_slice_info_def.var_offset.append(i)
            for i in self.var_shape:
                save_slice_info_def.var_shape.append(i)
            return save_slice_info_def
        else:
            return None


def _set_save_slice_info(self, save_slice_info):
    """Stores `save_slice_info` on this object as `_save_slice_info`."""
    self._save_slice_info = save_slice_info


def _get_save_slice_info(self):
    """Returns the value stored in `_save_slice_info`."""
    return self._save_slice_info
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             **kwargs):
  """Builds a `MetaGraphDef` proto for a graph and optionally writes it out.

  The graph (or the sub-graph under `export_scope`), together with whatever
  `create_meta_graph_def` adds from `**kwargs`, is packed into a `MetaGraphDef`
  so that it can be imported later.

  Args:
    filename: Optional path. When set, the resulting proto is written there
      through `graph_io.write_graph`.
    graph_def: Optional `GraphDef` protocol buffer to export instead of the
      nodes of `graph`.
    graph: The `Graph` to export from. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract the
      subgraph. The scope name is stripped from the node definitions so they
      can be imported into a new name scope later. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided (and nodes are
      being rewritten because `export_scope` or `clear_devices` is set), the
      graph collection of this name is cleared and refilled with the unbound
      inputs gathered while rewriting the nodes.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    **kwargs: Optional keyed arguments forwarded to `create_meta_graph_def`,
      including meta_info_def, saver_def, collection_list.

  Returns:
    A tuple of the `MetaGraphDef` proto and a dictionary mapping
    scope-stripped names to the `Variables` in the exported name scope.

  Raises:
    ValueError: When the `GraphDef` assembled from `graph` reaches 2GB.
  """
  graph = graph or ops.get_default_graph()

  unbound_inputs = []
  rewrite_nodes = export_scope or clear_devices

  if rewrite_nodes and graph_def:
    # Filter and rewrite the caller-supplied GraphDef into a fresh proto.
    filtered_graph_def = graph_pb2.GraphDef()
    filtered_graph_def.versions.CopyFrom(graph_def.versions)
    for source_node in graph_def.node:
      if not _should_include_node(source_node.name, export_scope):
        continue
      rewritten = _node_def(source_node, export_scope, unbound_inputs,
                            clear_devices=clear_devices)
      filtered_graph_def.node.extend([rewritten])
    graph_def = filtered_graph_def
  elif rewrite_nodes:
    # Only do this complicated work if we want to remove a name scope.
    graph_def = graph_pb2.GraphDef()
    # pylint: disable=protected-access
    graph_def.versions.CopyFrom(graph.graph_def_versions)
    total_bytes = 0
    for node_id in sorted(graph._nodes_by_id):
      op = graph._nodes_by_id[node_id]
      # pylint: enable=protected-access
      if not _should_include_node(op.name, export_scope):
        continue
      rewritten = _node_def(op.node_def, export_scope, unbound_inputs,
                            clear_devices=clear_devices)
      graph_def.node.extend([rewritten])
      if op.outputs:
        # `extend` stored a copy, so annotate the stored node itself.
        exported_node = graph_def.node[-1]
        assert "_output_shapes" not in exported_node.attr
        output_shapes = [out.get_shape().as_proto() for out in op.outputs]
        exported_node.attr["_output_shapes"].list.shape.extend(output_shapes)
      total_bytes += op.node_def.ByteSize()
      if total_bytes >= (1 << 31) or total_bytes < 0:
        raise ValueError("GraphDef cannot be larger than 2GB.")

  # It's possible that not all the inputs are in the export_scope.
  # If we would like such information included in the exported meta_graph,
  # add them to a special unbound_inputs collection.
  if rewrite_nodes and unbound_inputs_col_name:
    # Start from an empty collection so stale entries are not exported.
    graph.clear_collection(unbound_inputs_col_name)
    for unbound_name in unbound_inputs:
      graph.add_to_collection(unbound_inputs_col_name, unbound_name)

  scoped_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                          scope=export_scope)
  var_list = {
      ops.strip_name_scope(var.name, export_scope): var
      for var in scoped_variables
      if _should_include_node(var, export_scope)
  }

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      **kwargs)

  if filename:
    directory, basename = os.path.dirname(filename), os.path.basename(filename)
    graph_io.write_graph(scoped_meta_graph_def, directory, basename,
                         as_text=as_text)

  return scoped_meta_graph_def, var_list
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the target graph,
  recreates all the collections, and returns a dictionary of the variables
  found in the `GLOBAL_VARIABLES` collection under `import_scope`.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. Note that the device fields are cleared
      in place on the `MetaGraphDef` that is being imported.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Output` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Output` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.

  Returns:
    A dictionary of all the `Variables` imported into the name scope, keyed by
    their names with `import_scope` stripped.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        # The keys of `input_map` must match the recorded unbound inputs
        # exactly; anything else is reported as an error.
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([compat.as_str(v) for v in field.value]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      # A registered from_proto function can only be applied to serialized
      # protos. Any other kind falls through to the generic handling below
      # rather than tripping an assert (which would also be skipped under -O,
      # silently dropping the collection).
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
      else:
        if from_proto:
          logging.warning(
              "Collection '%s' has a registered from_proto function but was "
              "saved as '%s' rather than 'bytes_list'; restoring its raw "
              "values instead.", key, kind)
        field = getattr(col_def, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, import_scope))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, import_scope))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=import_scope)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, import_scope)] = v

  return var_list
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the target graph,
  recreates all the collections, and returns a dictionary of the variables
  found in the `GLOBAL_VARIABLES` collection under `import_scope`.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. Note that the device fields are cleared
      in place on the `MetaGraphDef` that is being imported.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.

  Returns:
    A dictionary of all the `Variables` imported into the name scope, keyed by
    their names with `import_scope` stripped.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        # The keys of `input_map` must match the recorded unbound inputs
        # exactly; anything else is reported as an error.
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([compat.as_str(v) for v in field.value]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      # A registered from_proto function can only be applied to serialized
      # protos. Any other kind falls through to the generic handling below
      # rather than tripping an assert (which would also be skipped under -O,
      # silently dropping the collection).
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
      else:
        if from_proto:
          logging.warning(
              "Collection '%s' has a registered from_proto function but was "
              "saved as '%s' rather than 'bytes_list'; restoring its raw "
              "values instead.", key, kind)
        field = getattr(col_def, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, import_scope))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, import_scope))

    var_list = {}
    # Use GLOBAL_VARIABLES, the key every other import/export routine in this
    # file reads, instead of the older VARIABLES name.
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=import_scope)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, import_scope)] = v

  return var_list
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Serialization problems are not propagated: they are logged as a warning and
  any partially written entry for `key` is removed from `meta_graph_def`.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string. Keys that are neither
      strings nor bytes are skipped with a warning.
    graph: The `Graph` from which to get collections. If `None`, use the
      default graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set). An empty or `None` value falls
      back to the graph's collection.

  Raises:
    TypeError: If `graph` is given but is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    # Format the message explicitly; passing the type as a second exception
    # argument would leave the "%s" placeholder unexpanded.
    raise TypeError("graph must be of type Graph, not %s" % (type(graph),))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  if override_contents:
    collection_list = override_contents
  else:
    collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [x for x in collection_list if
                     _should_include_node(x, export_scope, exclude_nodes)]
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          assert isinstance(proto, proto_type)
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      # The kind of the whole collection is derived from its first element.
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(collection_list)
  except Exception as e:  # pylint: disable=broad-except
    # Deliberately best-effort: an unserializable collection is dropped from
    # the proto instead of failing the whole export.
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the target graph,
  recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. Note that the device fields are cleared
      in place on the `MetaGraphDef` that is being imported.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs, or if eager
      execution is enabled.
  """
  if context.executing_eagerly():
    raise ValueError(
        "Exporting/importing meta graphs is not supported when "
        "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        # Decode once: the stored values may be bytes while `input_map` is
        # keyed by str, so membership must be tested on the decoded names.
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (
            not input_map or sorted(unbound_names) != sorted(input_map)):
          missing_names = [
              name for name in unbound_names
              if not input_map or name not in input_map
          ]
          raise ValueError(
              "Graph contains unbound inputs: %s. Must "
              "provide these inputs through input_map." %
              ",".join(missing_names))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    # Ask the graph which scope name the import will use without reserving it,
    # so restored collection entries get the same prefix as the nodes.
    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    # Maps a serialized proto to the object built from it, so the same
    # serialized variable appearing in several collections is built only once.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error(
            "Cannot identify data type for collection %s. Skipping.", key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
def import_scoped_meta_graph_with_return_elements(
    meta_graph_or_file,
    clear_devices=False,
    graph=None,
    import_scope=None,
    input_map=None,
    unbound_inputs_col_name="unbound_inputs",
    restore_collections_predicate=(lambda key: True),
    return_elements=None):
  """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the target graph,
  recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. Note that the device fields are cleared
      in place on the `MetaGraphDef` that is being imported.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements: A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If the graph_def contains unbound inputs, or if eager
      execution is enabled.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        # Decode once: the stored values may be bytes while `input_map` is
        # keyed by str, so membership must be tested on the decoded names.
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (
            not input_map or sorted(unbound_names) != sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." % ",".join(
                               name for name in unbound_names
                               if not input_map or name not in input_map))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    # Ask the graph which scope name the import will use without reserving it,
    # so restored collection entries get the same prefix as the nodes.
    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    imported_return_elements = importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list,
        return_elements=return_elements)

    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
    # without a VariableDef.trainable field set.
    tf_version = meta_graph_def.meta_info_def.tensorflow_version
    if not tf_version:
      variables_have_trainable = True
    else:
      variables_have_trainable = (
          distutils_version.LooseVersion(tf_version)
          >= distutils_version.LooseVersion("1.9"))

    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
    # variables to trainable if the value is not set in their VariableDef.
    sorted_collections = []
    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
      sorted_collections.append(
          (ops.GraphKeys.TRAINABLE_VARIABLES,
           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
    for key, value in sorted(meta_graph_def.collection_def.items()):
      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
        sorted_collections.append((key, value))

    # Restores all the other collections.
    # Maps a serialized proto to the object built from it, so the same
    # serialized variable appearing in several collections is built only once.
    variable_objects = {}
    for key, col_def in sorted_collections:
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)

      # Temporary change to allow the TFMA evaluator to read metric variables
      # saved as a bytes list.
      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
      if key == ops.GraphKeys.METRIC_VARIABLES:
        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
        from_proto = ops.get_from_proto_function(
            ops.GraphKeys.GLOBAL_VARIABLES)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              if not variables_have_trainable:
                # If the VariableDef proto does not contain a "trainable"
                # property because it was exported before that property was
                # added, we default it to whether the variable is in the
                # TRAINABLE_VARIABLES collection. We've sorted
                # TRAINABLE_VARIABLES to be first, so trainable variables will
                # be created from that collection.
                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list, imported_return_elements
def strip_name_scope(name_scope):
  """Returns `ops.strip_name_scope(name_scope, export_scope)`.

  `export_scope` is not a parameter; it is resolved from a scope outside this
  function (presumably the enclosing function's argument -- TODO confirm).

  Args:
    name_scope: The name to pass through `ops.strip_name_scope`.
  """
  return ops.strip_name_scope(name_scope, export_scope)