def _init_from_proto(self, queue_runner_def, import_scope=None):
  """Restores this QueueRunner's state from a `QueueRunnerDef` proto.

  Every op/tensor name stored in the proto is prefixed with `import_scope`
  and then resolved in the current default graph.

  Args:
    queue_runner_def: `QueueRunnerDef` protocol buffer to restore from. It is
      required (asserted to be a `QueueRunnerDef`), not optional.
    import_scope: Optional `string`. Name scope to add.
  """
  assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
  graph = ops.get_default_graph()

  def _lookup(name):
    # Resolve a stored name, re-rooted under import_scope, to a graph element.
    return graph.as_graph_element(ops.prepend_name_scope(name, import_scope))

  self._queue = _lookup(queue_runner_def.queue_name)
  self._enqueue_ops = [
      _lookup(op_name) for op_name in queue_runner_def.enqueue_op_name
  ]
  self._close_op = _lookup(queue_runner_def.close_op_name)
  self._cancel_op = _lookup(queue_runner_def.cancel_op_name)

  exception_types = tuple(
      errors.exception_type_from_error_code(code)
      for code in queue_runner_def.queue_closed_exception_types)
  # QueueRunnerDefs written before `queue_closed_exception_types` existed
  # leave it empty; fall back to OutOfRangeError for those.
  self._queue_closed_exception_types = (
      exception_types or (errors.OutOfRangeError,))
def _init_from_proto(self, queue_runner_def, import_scope=None):
  """Populates this QueueRunner from a serialized `QueueRunnerDef`.

  Args:
    queue_runner_def: The `QueueRunnerDef` protocol buffer to restore from.
      It is required: the argument is asserted to be a `QueueRunnerDef`.
    import_scope: Optional `string`. Name scope prepended to every op and
      tensor name in the proto before it is looked up in the default graph.
  """
  assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
  default_graph = ops.get_default_graph()
  as_element = default_graph.as_graph_element

  self._queue = as_element(
      ops.prepend_name_scope(queue_runner_def.queue_name, import_scope))
  enqueue_ops = []
  for op_name in queue_runner_def.enqueue_op_name:
    enqueue_ops.append(
        as_element(ops.prepend_name_scope(op_name, import_scope)))
  self._enqueue_ops = enqueue_ops
  self._close_op = as_element(
      ops.prepend_name_scope(queue_runner_def.close_op_name, import_scope))
  self._cancel_op = as_element(
      ops.prepend_name_scope(queue_runner_def.cancel_op_name, import_scope))

  closed_types = tuple(
      errors.exception_type_from_error_code(code)
      for code in queue_runner_def.queue_closed_exception_types)
  if not closed_types:
    # Legacy QueueRunnerDefs predate this field; default to OutOfRangeError.
    closed_types = (errors.OutOfRangeError,)
  self._queue_closed_exception_types = closed_types
def _init_from_proto(self, variable_def, import_scope=None):
  """Rebuilds this variable from a `VariableDef` protocol buffer.

  Args:
    variable_def: `VariableDef` protocol buffer describing the variable.
    import_scope: Optional `string`. Name scope prepended to the variable,
      initializer and snapshot names before they are looked up.
  """
  assert isinstance(variable_def, variable_pb2.VariableDef)
  graph = ops.get_default_graph()

  def _element(name):
    # Look up `name`, re-rooted under import_scope, in the default graph.
    return graph.as_graph_element(
        ops.prepend_name_scope(name, import_scope=import_scope))

  self._variable = _element(variable_def.variable_name)
  self._initializer_op = _element(variable_def.initializer_name)
  self._snapshot = _element(variable_def.snapshot_name)
  # NOTE(review): import_scope is not forwarded to SaveSliceInfo, so the
  # restored full_name is not prefixed with the import scope. Newer revisions
  # pass import_scope=import_scope; confirm the SaveSliceInfo in use accepts
  # that argument before changing this call.
  self._save_slice_info = (
      Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def)
      if variable_def.HasField("save_slice_info_def") else None)
  self._caching_device = None
def _init_from_proto(self, variable_def, import_scope=None):
  """Initializes this resource variable from a `VariableDef` proto.

  Args:
    variable_def: `VariableDef` protocol buffer; must have `is_resource` set.
    import_scope: Optional `string`. Name scope prepended to every name read
      from `variable_def` before lookup.

  Raises:
    ValueError: If `variable_def` does not describe a resource variable.
  """
  assert isinstance(variable_def, variable_pb2.VariableDef)
  if not variable_def.is_resource:
    raise ValueError("Trying to restore Variable as ResourceVariable.")

  graph = ops.get_default_graph()

  def _resolve(name):
    scoped_name = ops.prepend_name_scope(name, import_scope=import_scope)
    return graph.as_graph_element(scoped_name)

  self._handle = _resolve(variable_def.variable_name)
  self._initialize_op = _resolve(variable_def.initializer_name)
  # An empty snapshot_name means no cached read tensor was recorded.
  self._cached_value = (
      _resolve(variable_def.snapshot_name)
      if variable_def.snapshot_name else None)
  # NOTE(review): import_scope is not forwarded to SaveSliceInfo here; newer
  # revisions do so. Confirm the SaveSliceInfo in use accepts it first.
  self._save_slice_info = (
      variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def)
      if variable_def.HasField("save_slice_info_def") else None)
  self._caching_device = None
  self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
def _init_from_proto(self, variable_def, import_scope=None):
  """Initializes this resource variable from a `VariableDef` proto.

  Only valid in graph mode.

  Args:
    variable_def: `VariableDef` protocol buffer; must have `is_resource` set.
    import_scope: Optional `string`. Name scope prepended to every name read
      from `variable_def` before lookup.

  Raises:
    ValueError: If `variable_def` does not describe a resource variable.
  """
  assert context.in_graph_mode()
  assert isinstance(variable_def, variable_pb2.VariableDef)
  if not variable_def.is_resource:
    raise ValueError("Trying to restore Variable as ResourceVariable.")

  graph = ops.get_default_graph()

  def _scoped_lookup(name):
    return graph.as_graph_element(
        ops.prepend_name_scope(name, import_scope=import_scope))

  self._handle = _scoped_lookup(variable_def.variable_name)
  self._handle_name = self._handle.name
  self._initializer_op = _scoped_lookup(variable_def.initializer_name)
  if not variable_def.snapshot_name:
    self._cached_value = None
  else:
    self._cached_value = _scoped_lookup(variable_def.snapshot_name)
  if not variable_def.HasField("save_slice_info_def"):
    self._save_slice_info = None
  else:
    # NOTE(review): import_scope is not forwarded to SaveSliceInfo; later
    # revisions pass it. Confirm the SaveSliceInfo in use accepts it first.
    self._save_slice_info = variables.Variable.SaveSliceInfo(
        save_slice_info_def=variable_def.save_slice_info_def)
  self._caching_device = None
  self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
  # Computed after the fields above are set; presumably value() reads them
  # (e.g. _cached_value) -- confirm before reordering.
  self._graph_element = self.value()
  self._constraint = None
def _init_from_proto(self, variable_def, import_scope=None):
  """Rebuilds this variable from a `VariableDef` protocol buffer.

  Args:
    variable_def: `VariableDef` protocol buffer describing the variable.
    import_scope: Optional `string`. Name scope prepended to every name read
      from `variable_def` (also forwarded to `SaveSliceInfo`).
  """
  assert isinstance(variable_def, variable_pb2.VariableDef)
  graph = ops.get_default_graph()

  def _get(name):
    scoped = ops.prepend_name_scope(name, import_scope=import_scope)
    return graph.as_graph_element(scoped)

  self._variable = _get(variable_def.variable_name)
  self._initializer_op = _get(variable_def.initializer_name)
  # initial_value_name may be missing from protos written by older versions.
  if getattr(variable_def, "initial_value_name", None):
    self._initial_value = _get(variable_def.initial_value_name)
  else:
    self._initial_value = None
  self._snapshot = _get(variable_def.snapshot_name)
  if variable_def.HasField("save_slice_info_def"):
    self._save_slice_info = Variable.SaveSliceInfo(
        save_slice_info_def=variable_def.save_slice_info_def,
        import_scope=import_scope)
  else:
    self._save_slice_info = None
  self._caching_device = None
  self._constraint = None
def _init_from_proto(self, variable_def, import_scope=None):
  """Initializes this resource variable from a `VariableDef` proto.

  Only supported in graph mode.

  Args:
    variable_def: `VariableDef` protocol buffer; must have `is_resource` set.
    import_scope: Optional `string`. Name scope prepended to every name read
      from `variable_def` (also forwarded to `SaveSliceInfo`).

  Raises:
    ValueError: If `variable_def` does not describe a resource variable.
  """
  # init_from_proto is currently not supported in Eager mode.
  assert not context.executing_eagerly()
  self._in_graph_mode = True
  assert isinstance(variable_def, variable_pb2.VariableDef)
  if not variable_def.is_resource:
    raise ValueError("Trying to restore Variable as ResourceVariable.")

  graph = ops.get_default_graph()

  def _from_scope(name):
    return graph.as_graph_element(
        ops.prepend_name_scope(name, import_scope=import_scope))

  self._handle = _from_scope(variable_def.variable_name)
  self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
  self._handle_name = self._handle.name
  self._unique_id = self._handle_name
  self._initializer_op = _from_scope(variable_def.initializer_name)

  # Protos from older versions may lack initial_value_name.
  if getattr(variable_def, "initial_value_name", None):
    self._initial_value = _from_scope(variable_def.initial_value_name)
  else:
    self._initial_value = None

  self._trainable = getattr(variable_def, "trainable", True)

  if not variable_def.snapshot_name:
    self._cached_value = None
    # Legacy protos without a snapshot name: assume the read tensor has the
    # default name below the handle op.
    self._graph_element = graph.get_tensor_by_name(
        self._handle.op.name + "/Read/ReadVariableOp:0")
  else:
    read_tensor = _from_scope(variable_def.snapshot_name)
    self._cached_value = read_tensor
    # Follow first inputs back until reaching the ReadVariableOp output.
    while read_tensor.op.type != "ReadVariableOp":
      read_tensor = read_tensor.op.inputs[0]
    self._graph_element = read_tensor

  if variable_def.HasField("save_slice_info_def"):
    self._save_slice_info = variables.Variable.SaveSliceInfo(
        save_slice_info_def=variable_def.save_slice_info_def,
        import_scope=import_scope)
  else:
    self._save_slice_info = None
  self._caching_device = None
  self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
  self._constraint = None
  self._cached_shape_as_list = None
def _build(self, checkpoint_path, build_save, build_restore):
  """Builds `self.saver_def`, or re-scopes an existing one.

  In graph mode the work happens at most once per Saver (guarded by
  `self._is_built`); when executing eagerly the saver_def is rebuilt on every
  call.

  Args:
    checkpoint_path: Forwarded to the builder as `filename`.
    build_save: Forwarded to the builder's `_build_internal`.
    build_restore: Forwarded to the builder's `_build_internal`.

  Raises:
    ValueError: If there are no variables to save and `self._allow_empty` is
      False.
  """
  if not context.executing_eagerly():
    if self._is_built:
      return
    self._is_built = True

  must_build = not self.saver_def or context.executing_eagerly()
  if must_build:
    if self._builder is None:
      # Lazily create the builder; this Saver uses SecureBulkSaverBuilder.
      self._builder = SecureBulkSaverBuilder(self._write_version)
    if self._var_list is None:
      # pylint: disable=protected-access
      self._var_list = variables._all_saveable_objects()
    if not self._var_list:
      if not self._allow_empty:
        raise ValueError("No variables to save")
      self._is_empty = True
      return
    self._is_empty = False
    self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
        self._var_list,
        reshape=self._reshape,
        sharded=self._sharded,
        max_to_keep=self._max_to_keep,
        keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
        name=self._name,
        restore_sequentially=self._restore_sequentially,
        filename=checkpoint_path,
        build_save=build_save,
        build_restore=build_restore)
  elif self.saver_def and self._name:
    # self._name is used as a name_scope by the builder, so it doubles as the
    # "import_scope": re-root the names of an existing saver_def under it.
    for field_name in ("filename_tensor_name", "save_tensor_name",
                       "restore_op_name"):
      setattr(self.saver_def, field_name,
              ops.prepend_name_scope(
                  getattr(self.saver_def, field_name), self._name))

  self._check_saver_def()

  if not context.executing_eagerly():
    # Schedule the next "keep" checkpoint (done in __init__ when eager).
    self._next_checkpoint_time = (
        time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)
def __init__(self,
             full_name=None,
             full_shape=None,
             var_offset=None,
             var_shape=None,
             save_slice_info_def=None,
             import_scope=None):
  """Create a `SaveSliceInfo`.

  Args:
    full_name: Name of the full variable of which this `Variable` is a slice.
    full_shape: Shape of the full variable, as a list of int.
    var_offset: Offset of this `Variable` into the full variable, as a
      list of int.
    var_shape: Shape of this `Variable`, as a list of int.
    save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`,
      recreates the SaveSliceInfo object from its contents.
      `save_slice_info_def` and the other arguments are mutually exclusive.
    import_scope: Optional `string`. Name scope prepended to `full_name`.
      Only used when initializing from a protocol buffer.
  """
  if not save_slice_info_def:
    # Plain construction from explicit arguments.
    self.full_name = full_name
    self.full_shape = full_shape
    self.var_offset = var_offset
    self.var_shape = var_shape
    return

  assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
  self.full_name = ops.prepend_name_scope(
      save_slice_info_def.full_name, import_scope=import_scope)
  # Copy the repeated proto fields into plain Python lists.
  self.full_shape = list(save_slice_info_def.full_shape)
  self.var_offset = list(save_slice_info_def.var_offset)
  self.var_shape = list(save_slice_info_def.var_shape)
def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
  """Restores collections that we need to keep.

  Copies each collection named in `collection_keys` from
  `src_meta_graph_def.collection_def` into `dest_graph`. Because the graph
  may have been optimized after export, entries that can no longer be
  reconstructed in `dest_graph` are skipped rather than raising.

  Args:
    dest_graph: `Graph` whose collections are populated.
    src_meta_graph_def: `MetaGraphDef` holding the exported `collection_def`
      map. It is not modified.
    collection_keys: Iterable of collection keys to restore.
  """
  # Names are looked up unchanged (no import scope is prepended).
  scope = ""
  collection_defs = src_meta_graph_def.collection_def
  for key in collection_keys:
    if key in collection_defs:
      collection_def = collection_defs[key]
      kind = collection_def.WhichOneof("kind")
    else:
      # Do not index with a missing key: protobuf maps insert an empty entry,
      # which would mutate the caller's MetaGraphDef.
      collection_def, kind = None, None
    if kind is None:
      tf_logging.error(
          "Cannot identify data type for collection %s. Skipping.", key)
      continue
    from_proto = ops.get_from_proto_function(key)
    if from_proto and kind == "bytes_list":
      proto_type = ops.get_collection_proto_type(key)
      # It is assumed that there are no Variables Keys in collections.
      for value in collection_def.bytes_list.value:
        proto = proto_type()
        proto.ParseFromString(value)
        try:
          new_value = from_proto(proto, import_scope=scope)
        except Exception as e:  # pylint: disable=broad-except
          # Best effort: the object may reference nodes removed by graph
          # optimization. Catching Exception (not a bare except) still lets
          # KeyboardInterrupt / SystemExit propagate.
          tf_logging.debug(
              "Skipping entry of collection %s that could not be restored: %s",
              key, e)
          continue
        dest_graph.add_to_collection(key, new_value)
    else:
      field = getattr(collection_def, kind)
      if kind == "node_list":
        for value in field.value:
          name = ops.prepend_name_scope(value, scope)
          # Since the graph has been optimized, the node may no longer exist.
          try:
            col_op = dest_graph.as_graph_element(name)
          except (TypeError, ValueError, KeyError):
            continue
          dest_graph.add_to_collection(key, col_op)
      elif kind == "int64_list":
        # NOTE(opensource): This force conversion is to work around the
        # fact that Python2 distinguishes between int and long, while
        # Python3 has only int.
        for value in field.value:
          dest_graph.add_to_collection(key, int(value))
      else:
        for value in field.value:
          dest_graph.add_to_collection(
              key, ops.prepend_name_scope(value, scope))
def _init_from_proto(self, variable_def, import_scope=None):
  """Initializes this resource variable from a `VariableDef` proto.

  Only supported in graph mode.

  Args:
    variable_def: `VariableDef` protocol buffer; must have `is_resource` set.
    import_scope: Optional `string`. Name scope prepended to every name read
      from `variable_def` (also forwarded to `SaveSliceInfo`).

  Raises:
    ValueError: If `variable_def` does not describe a resource variable.
  """
  # init_from_proto is currently not supported in Eager mode.
  assert context.in_graph_mode()
  self._in_graph_mode = True
  assert isinstance(variable_def, variable_pb2.VariableDef)
  if not variable_def.is_resource:
    raise ValueError("Trying to restore Variable as ResourceVariable.")

  graph = ops.get_default_graph()

  def _scoped_element(name):
    return graph.as_graph_element(
        ops.prepend_name_scope(name, import_scope=import_scope))

  self._handle = _scoped_element(variable_def.variable_name)
  self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
  self._handle_device = self._handle.device
  self._handle_name = self._handle.name
  self._initializer_op = _scoped_element(variable_def.initializer_name)

  # initial_value_name may be absent in protos written by older versions.
  initial_value_name = getattr(variable_def, "initial_value_name", None)
  self._initial_value = (
      _scoped_element(initial_value_name) if initial_value_name else None)
  self._cached_value = (
      _scoped_element(variable_def.snapshot_name)
      if variable_def.snapshot_name else None)

  if variable_def.HasField("save_slice_info_def"):
    self._save_slice_info = variables.Variable.SaveSliceInfo(
        save_slice_info_def=variable_def.save_slice_info_def,
        import_scope=import_scope)
  else:
    self._save_slice_info = None
  self._caching_device = None
  self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
  # Computed after the fields above; presumably value() reads them -- confirm
  # before reordering.
  self._graph_element = self.value()
  self._constraint = None
def _init_from_proto(self, variable_def, import_scope=None):
  """Restores this resource variable's graph state from a `VariableDef`.

  Only supported in graph mode.

  Args:
    variable_def: `VariableDef` protocol buffer; must have `is_resource` set.
    import_scope: Optional `string`. Name scope prepended to every name read
      from `variable_def` (also forwarded to `SaveSliceInfo`).

  Raises:
    ValueError: If `variable_def` does not describe a resource variable.
  """
  # init_from_proto is currently not supported in Eager mode.
  assert context.in_graph_mode()
  self._in_graph_mode = True
  assert isinstance(variable_def, variable_pb2.VariableDef)
  if not variable_def.is_resource:
    raise ValueError("Trying to restore Variable as ResourceVariable.")

  g = ops.get_default_graph()
  scope_kwargs = {"import_scope": import_scope}

  handle = g.as_graph_element(
      ops.prepend_name_scope(variable_def.variable_name, **scope_kwargs))
  self._handle = handle
  self._shape = tensor_shape.TensorShape(handle.op.get_attr("shape"))
  self._handle_device = handle.device
  self._handle_name = handle.name
  self._initializer_op = g.as_graph_element(
      ops.prepend_name_scope(variable_def.initializer_name, **scope_kwargs))

  # Older protos may not define initial_value_name at all.
  if not getattr(variable_def, "initial_value_name", None):
    self._initial_value = None
  else:
    self._initial_value = g.as_graph_element(
        ops.prepend_name_scope(variable_def.initial_value_name,
                               **scope_kwargs))

  if not variable_def.snapshot_name:
    self._cached_value = None
  else:
    self._cached_value = g.as_graph_element(
        ops.prepend_name_scope(variable_def.snapshot_name, **scope_kwargs))

  if not variable_def.HasField("save_slice_info_def"):
    self._save_slice_info = None
  else:
    self._save_slice_info = variables.Variable.SaveSliceInfo(
        save_slice_info_def=variable_def.save_slice_info_def,
        import_scope=import_scope)
  self._caching_device = None
  self._dtype = dtypes.as_dtype(handle.op.get_attr("dtype"))
  # Evaluated last; presumably value() depends on the state set above.
  self._graph_element = self.value()
  self._constraint = None
def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
  """Restores collections that we need to keep.

  Copies each collection named in `collection_keys` from
  `src_meta_graph_def.collection_def` into `dest_graph`. Because the graph
  may have been optimized after export, entries that can no longer be
  reconstructed in `dest_graph` are skipped rather than raising.

  Args:
    dest_graph: `Graph` whose collections are populated.
    src_meta_graph_def: `MetaGraphDef` holding the exported `collection_def`
      map. It is not modified.
    collection_keys: Iterable of collection keys to restore.
  """
  # Names are looked up unchanged (no import scope is prepended).
  scope = ""
  collection_defs = src_meta_graph_def.collection_def
  for key in collection_keys:
    if key in collection_defs:
      collection_def = collection_defs[key]
      kind = collection_def.WhichOneof("kind")
    else:
      # Do not index with a missing key: protobuf maps insert an empty entry,
      # which would mutate the caller's MetaGraphDef.
      collection_def, kind = None, None
    if kind is None:
      tf_logging.error(
          "Cannot identify data type for collection %s. Skipping.", key)
      continue
    from_proto = ops.get_from_proto_function(key)
    if from_proto and kind == "bytes_list":
      proto_type = ops.get_collection_proto_type(key)
      # It is assumed that there are no Variables Keys in collections.
      for value in collection_def.bytes_list.value:
        proto = proto_type()
        proto.ParseFromString(value)
        try:
          new_value = from_proto(proto, import_scope=scope)
        except Exception as e:  # pylint: disable=broad-except
          # Best effort: the object may reference nodes removed by graph
          # optimization. Catching Exception (not a bare except) still lets
          # KeyboardInterrupt / SystemExit propagate.
          tf_logging.debug(
              "Skipping entry of collection %s that could not be restored: %s",
              key, e)
          continue
        dest_graph.add_to_collection(key, new_value)
    else:
      field = getattr(collection_def, kind)
      if kind == "node_list":
        for value in field.value:
          name = ops.prepend_name_scope(value, scope)
          # Since the graph has been optimized, the node may no longer exist.
          try:
            col_op = dest_graph.as_graph_element(name)
          except (TypeError, ValueError, KeyError):
            continue
          dest_graph.add_to_collection(key, col_op)
      elif kind == "int64_list":
        # NOTE(opensource): This force conversion is to work around the
        # fact that Python2 distinguishes between int and long, while
        # Python3 has only int.
        for value in field.value:
          dest_graph.add_to_collection(key, int(value))
      else:
        for value in field.value:
          dest_graph.add_to_collection(
              key, ops.prepend_name_scope(value, scope))
def _init_from_proto(self, variable_def, import_scope=None):
  """Initializes this EmbeddingVariable from a `VariableDef` proto.

  Only supported in graph mode.

  Args:
    variable_def: `VariableDef` protocol buffer; must have `is_resource` set.
    import_scope: Optional `string`. Name scope prepended to every name read
      from `variable_def`, including the save-slice full name.

  Raises:
    ValueError: If `variable_def` does not describe a resource variable.
  """
  # Note that init_from_proto is currently not supported in Eager mode.
  assert not context.executing_eagerly()
  self._in_graph_mode = True
  assert isinstance(variable_def, variable_pb2.VariableDef)
  if not variable_def.is_resource:
    raise ValueError(
        "Trying to restore Variable as EmbeddingVariable.")
  # Create from variable_def.
  g = ops.get_default_graph()
  self._handle = g.as_graph_element(
      ops.prepend_name_scope(variable_def.variable_name,
                             import_scope=import_scope))
  self._graph_shape = tensor_shape.TensorShape(
      self._handle.op.get_attr("shape"))
  self._handle_device = self._handle.device
  self._handle_name = self._handle.name
  self._initializer_op = g.as_graph_element(
      ops.prepend_name_scope(variable_def.initializer_name,
                             import_scope=import_scope))
  self._trainable = getattr(variable_def, "trainable", True)
  if variable_def.snapshot_name:
    self._cached_value = g.as_graph_element(
        ops.prepend_name_scope(variable_def.snapshot_name,
                               import_scope=import_scope))
  else:
    self._cached_value = None
  if variable_def.HasField("save_slice_info_def"):
    # Forward import_scope so the slice's full_name is re-rooted under the
    # same scope as the variable, as the ResourceVariable variants do.
    self._save_slice_info = variables.Variable.SaveSliceInfo(
        save_slice_info_def=variable_def.save_slice_info_def,
        import_scope=import_scope)
  else:
    self._save_slice_info = None
  self._caching_device = None
  self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
  # NOTE(review): -1 appears to be a sentinel key value; confirm its meaning.
  self._invalid_key = -1
  # A constant [0] tensor stands in for the initial value when restoring;
  # NOTE(review): confirm no caller expects the original initializer here.
  self._initial_value = ops.convert_to_tensor([0], name="initial_value",
                                              dtype=self._dtype)
  self._invalid_key_type = dtypes.as_dtype(
      self._handle.op.get_attr("Tkeys"))
  self._graph_element = None
  self._constraint = None
def _remap_feed(self, feed, feed_val=None):
    """Remap a feed onto the matching element(s) of the transformed graph.

    If the feed still exists under its own name in the transformed graph, it
    maps to that single element. Otherwise it is treated as replicated and
    maps to one copy per local replica (e.g. N copies of a placeholder for N
    replicas, all of which have to be fed).

    Args:
        feed: feed graph element or its name
        feed_val: optional feed value

    Returns:
        A tuple ``(transformed_feeds, expand_feed_val)``. ``transformed_feeds``
        is the list of remapped graph elements or, when ``feed_val`` is not
        None, a list of ``(new_feed, new_feed_value)`` pairs.
        ``expand_feed_val`` maps one feed value to a list holding one value
        per local replica.
    """
    if isinstance(feed, str):
        feed_name = feed
    else:
        feed_name = feed.name

    try:
        transformed_feeds = [
            self._graph_item.graph.as_graph_element(feed_name)
        ]
    except KeyError:
        # Not present under its own name: collect each replica's copy.
        transformed_feeds = []
        for replica_index in range(self._graph_transformer.num_local_replicas):
            replica_name = ops.prepend_name_scope(
                feed_name, replica_prefix(replica_index))
            transformed_feeds.append(
                self._graph_item.graph.as_graph_element(replica_name))

    num_replicated_feeds = self._graph_transformer.num_local_replicas
    if isinstance(feed, str):
        feed = transformed_feeds[0]

    def expand_feed_val(feed_val, feed=feed):
        """Given a original feed or replicated feed, expand the feed value."""
        # Replicated placeholders with an undefined (polymorphic) dimension get
        # feed_val split along that dimension; otherwise every replica gets
        # the same feed_val.
        polymorphic_dim = self._polymorphic_dim(feed)
        # NOTE(review): a truthiness test treats a polymorphic dim of 0 as
        # "not polymorphic"; confirm what _polymorphic_dim returns for axis 0.
        if not polymorphic_dim:
            return [feed_val for _ in range(num_replicated_feeds)]
        return np.array_split(np.asarray(feed_val), num_replicated_feeds,
                              axis=polymorphic_dim)

    if feed_val is not None:
        # NOTE(review): when the feed was found unreplicated, zip pairs the
        # single element with only the first expanded value.
        transformed_feeds = list(
            zip(transformed_feeds, expand_feed_val(feed_val)))
    return transformed_feeds, expand_feed_val
def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Looks up the graph element named by a `TensorInfo` proto.

  Args:
    tensor_info: A TensorInfo proto describing an Op or Tensor by name.
    graph: The tf.Graph to search. Falls back to the current default graph
      when falsy (e.g. None).
    import_scope: If not None, this string is prepended to the name in
      `tensor_info` before the lookup.

  Returns:
    The Op or Tensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to an op or tensor in
      `graph`.
  """
  if not graph:
    graph = ops.get_default_graph()
  element_name = ops.prepend_name_scope(tensor_info.name,
                                        import_scope=import_scope)
  return graph.as_graph_element(element_name)
def import_scoped_meta_graph_with_return_elements(
    meta_graph_or_file,
    clear_devices=False,
    graph=None,
    import_scope=None,
    input_map=None,
    unbound_inputs_col_name="unbound_inputs",
    restore_collections_predicate=(lambda key: True),
    return_elements=None):
  """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary
  of all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used
  to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A
      collection named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements: A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs that `input_map` does not exactly cover.
  """
  if context.executing_eagerly():
    raise ValueError(
        "Exporting/importing meta graphs is not supported when "
        "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        if kind is None:
          # An empty collection lists no unbound inputs (and getattr with a
          # None attribute name would raise TypeError).
          break
        field = getattr(col_def, kind)
        # bytes_list values are bytes; normalize once so they compare
        # correctly against the str keys of input_map.
        unbound_inputs = [compat.as_str(v) for v in field.value]
        if unbound_inputs and (
            not input_map or sorted(unbound_inputs) != sorted(input_map)):
          raise ValueError(
              "Graph contains unbound inputs: %s. Must "
              "provide these inputs through input_map." % ",".join(
                  v for v in unbound_inputs
                  if not input_map or v not in input_map))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    imported_return_elements = importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list,
        return_elements=return_elements)

    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
    # without a VariableDef.trainable field set.
    tf_version = meta_graph_def.meta_info_def.tensorflow_version
    if not tf_version:
      variables_have_trainable = True
    else:
      variables_have_trainable = (
          packaging_version.parse(tf_version) >=
          packaging_version.parse("1.9"))

    # Sort collections so we see TRAINABLE_VARIABLES first and can default
    # these variables to trainable if the value is not set in their
    # VariableDef.
    sorted_collections = []
    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
      sorted_collections.append(
          (ops.GraphKeys.TRAINABLE_VARIABLES,
           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
    for key, value in sorted(meta_graph_def.collection_def.items()):
      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
        sorted_collections.append((key, value))

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted_collections:
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error(
            "Cannot identify data type for collection %s. Skipping.", key)
        continue
      from_proto = ops.get_from_proto_function(key)

      # Temporary change to allow the TFMA evaluator to read metric variables
      # saved as a bytes list.
      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
      if key == ops.GraphKeys.METRIC_VARIABLES:
        # Metric variables will use the same proto functions as
        # GLOBAL_VARIABLES.
        from_proto = ops.get_from_proto_function(
            ops.GraphKeys.GLOBAL_VARIABLES)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            # A variable may appear in several collections; build it once.
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              if not variables_have_trainable:
                # If the VariableDef proto does not contain a "trainable"
                # property because it was exported before that property was
                # added, we default it to whether the variable is in the
                # TRAINABLE_VARIABLES collection. We've sorted
                # TRAINABLE_VARIABLES to be first, so trainable variables will
                # be created from that collection.
                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    scoped_variables = graph.get_collection(
        ops.GraphKeys.GLOBAL_VARIABLES, scope=scope_to_prepend_to_names)
    for v in scoped_variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list, imported_return_elements
def _get_tensor(name):
  """Returns the tensor called `name`, re-rooted under `import_scope`.

  `graph` and `import_scope` are free variables taken from the enclosing
  scope; this helper appears to be nested inside an import routine.
  """
  scoped_name = ops.prepend_name_scope(name, import_scope=import_scope)
  return graph.get_tensor_by_name(scoped_name)
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary
  of all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used
  to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A
      collection named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs that `input_map` does not exactly cover.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        if kind is None:
          # An empty collection lists no unbound inputs (and getattr with a
          # None attribute name would raise TypeError).
          break
        field = getattr(col_def, kind)
        # bytes_list values are bytes; normalize once so they compare
        # correctly against the str keys of input_map.
        unbound_inputs = [compat.as_str(v) for v in field.value]
        if unbound_inputs and (
            not input_map or sorted(unbound_inputs) != sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(v for v in unbound_inputs
                                    if not input_map or v not in input_map))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            # A variable may appear in several collections; build it once.
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    scoped_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                            scope=scope_to_prepend_to_names)
    for v in scoped_variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs that are not exactly covered by `input_map`.
  """
  if context.in_eager_mode():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        # Normalize to str once: collection values may be bytes while the keys
        # of `input_map` are str, so raw values would never match them.
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (
            not input_map or sorted(unbound_names) != sorted(input_map)):
          missing = [name for name in unbound_names
                     if not input_map or name not in input_map]
          # If every name is mapped (the mismatch comes from extra keys in
          # `input_map`), fall back to listing all unbound names.
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(missing or unbound_names))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list

    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    # Predict (without reserving) the scope name that the import will use, so
    # collection entries can be re-prefixed consistently with the nodes.
    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
def copy_ops_meta_graph(op_list, from_scope, to_scope, replace=None):
  """Copies a list of `Operation`s from one scope to another.

  Variables are shared between the original ops and their copies.

  Args:
    op_list: A list of `Operation` objects to be copied.
    from_scope: `String` name scope containing the ops to be copied.
    to_scope: `String` name scope under which the copied ops will reside.
    replace: Optional dictionary mapping input Tensors of these ops to their
      replacements. `None` (the default) means no replacements.

  Returns:
    A dictionary mapping each original op to its copy under `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same, or if an op in
      `op_list` is not under `from_scope`.
  """
  if from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph.")
  # Without this, the default `replace=None` makes `tensor in replace` raise
  # TypeError as soon as an input tensor is produced by one of the copied ops.
  if replace is None:
    replace = {}

  op_list = set(op_list)
  op_names = set(op.name for op in op_list)
  op_outputs = set()
  for op in op_list:
    if not op.name.startswith(from_scope):
      raise ValueError("The Operation (%s) to copy is not under "
                       "'from_scope'." % op.name)
    op_outputs.update(set(op.outputs))

  input_map = {}
  as_unbound_inputs = []
  for op in op_list:
    for tensor in op.inputs:
      # Inputs produced outside the copied ops, and inputs the caller wants
      # replaced, are exported as unbound inputs and re-bound via input_map.
      if tensor not in op_outputs or tensor in replace:
        name = tensor.name[:-2] if tensor.name[-2:] == ":0" else tensor.name
        as_unbound_inputs.append(name)
        if tensor in replace:
          input_map[_unbound_name(name)] = replace[tensor]
        else:
          input_map[_unbound_name(name)] = tensor

    for dep in op.control_inputs:
      if dep not in op_list:
        name = "^" + dep.name
        as_unbound_inputs.append(name)
        input_map[_unbound_name(name)] = dep

    for name in op.colocation_groups():
      # name[5:] drops a 5-character prefix (presumably "loc:@") to obtain the
      # colocated op's name.
      if name[5:] not in op_names:
        as_unbound_inputs.append(name)
        input_map[_unbound_name(name)] = (
            ops.get_default_graph().as_graph_element(name[5:]))

  orig_meta_graph = export_ops_meta_graph(
      op_list, export_scope=from_scope, as_unbound_inputs=as_unbound_inputs)
  _ = import_scoped_meta_graph(orig_meta_graph,
                               import_scope=to_scope,
                               input_map=input_map)

  copied_ops = {}
  for op in op_list:
    new_op_name = ops.prepend_name_scope(
        ops.strip_name_scope(op.name, from_scope), to_scope)
    new_op = ops.get_default_graph().as_graph_element(new_op_name,
                                                      allow_tensor=False)
    copied_ops[op] = new_op
  return copied_ops
def __iadd__(self, other): logging.log_first_n( logging.WARN, return self + other def __isub__(self, other): logging.log_first_n( logging.WARN, return self - other def __imul__(self, other): logging.log_first_n( logging.WARN, return self * other def __idiv__(self, other): logging.log_first_n( logging.WARN, return self / other def __itruediv__(self, other): logging.log_first_n( logging.WARN, return self / other def __irealdiv__(self, other): logging.log_first_n( logging.WARN, return self / other def __ipow__(self, other): logging.log_first_n( logging.WARN, return self ** other class SaveSliceInfo(object): def __init__(self, full_name=None, full_shape=None, var_offset=None, var_shape=None, save_slice_info_def=None, import_scope=None): if save_slice_info_def: assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef) self.full_name = ops.prepend_name_scope( save_slice_info_def.full_name, import_scope=import_scope) self.full_shape = [i for i in save_slice_info_def.full_shape] self.var_offset = [i for i in save_slice_info_def.var_offset] self.var_shape = [i for i in save_slice_info_def.var_shape] else: self.full_name = full_name self.full_shape = full_shape self.var_offset = var_offset self.var_shape = var_shape @property def spec(self): full_shape_str = " ".join(["%d" % d for d in self.full_shape]) + " " sl_spec = ":".join([ "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape) ]) return full_shape_str + sl_spec def to_proto(self, export_scope=None): if (export_scope is None or self.full_name.startswith(export_scope)): save_slice_info_def = variable_pb2.SaveSliceInfoDef() save_slice_info_def.full_name = ops.strip_name_scope( self.full_name, export_scope) for i in self.full_shape: save_slice_info_def.full_shape.append(i) for i in self.var_offset: save_slice_info_def.var_offset.append(i) for i in self.var_shape: save_slice_info_def.var_shape.append(i) return save_slice_info_def else: return None def _set_save_slice_info(self, save_slice_info): 
self._save_slice_info = save_slice_info def _get_save_slice_info(self): return self._save_slice_info
def prepend_name_scope(name_scope):
  """Prefixes `name_scope` with the `import_scope` visible in this scope.

  A one-argument adapter around `ops.prepend_name_scope`; `import_scope` is
  resolved from the surrounding scope, not passed by the caller.
  """
  prefixed = ops.prepend_name_scope(name_scope, import_scope=import_scope)
  return prefixed
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates all the collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs that are not exactly
      covered by `input_map`, or if a collection with a registered
      `from_proto` function is not stored as a `bytes_list`.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        # Normalize to str so names compare correctly with `input_map` keys.
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (
            not input_map or sorted(unbound_names) != sorted(input_map)):
          missing = [name for name in unbound_names
                     if not input_map or name not in input_map]
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(missing or unbound_names))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list

    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections, in sorted key order so the
    # restoration order is deterministic.
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto:
        # Explicit check instead of `assert`, which is stripped under -O.
        if kind != "bytes_list":
          raise ValueError(
              "Collection %s has a registered from_proto function and must be "
              "stored as a bytes_list, but is of kind %s." % (key, kind))
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
      else:
        field = getattr(col_def, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, import_scope))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, import_scope))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.VARIABLES,
                                     scope=import_scope)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, import_scope)] = v

  return var_list
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` and `filename` is given, also save the
      GraphDebugInfo to a separate file in the same directory as `filename`,
      named like `filename` but with its extension replaced by `.debug`
      (e.g. `model.meta` -> `model.debug`).
    **kwargs: Optional keyed arguments, including meta_info_def and
        collection_list.

  Returns:
    A `MetaGraphDef` proto and dictionary of `Variables` in the exported
    name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # Only do this complicated work if we want to remove a name scope.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          # `extend` copies into the repeated field, so annotate the stored
          # copy (graph_def.node[-1]) rather than the local `node_def`.
          if value.outputs:
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          bytesize += value.node_def.ByteSize()
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB. "
                             f"Received size: {bytesize}.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      # The extension of `filename` is replaced (not suffixed) by ".debug".
      name, _ = os.path.splitext(filename)
      debug_filename = "{name}{ext}".format(name=name, ext=".debug")

      # Gets the operation from the graph by the name. Excludes variable nodes,
      # so only the nodes in the frozen models are included.
      # TODO(liufengdb): fix this for functions.
      ops_to_export = []
      for node in scoped_meta_graph_def.graph_def.node:
        scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
        ops_to_export.append(
            ("", graph.get_operation_by_name(scoped_op_name)))

      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list
def _remap_fetch(self, fetch): """ Remap the user-provided fetches to the right list of fetches after graph transformations. Cases: * If original fetch exists (which is not affected by graph transformation), fetch the original. * Otherwise, for fetches that are train_ops, fetch them on all replicas; * for other fetches, only fetch it on master replica. * For example, for partitioned vars, it corresponds to the concat one as_tensor on the first replica. """ _remap_element = self._remap_element fetch_type = type(fetch) fetch_name = fetch if isinstance(fetch, str) else fetch.name contract_fn = lambda fetched_vals: fetched_vals[0] # noqa: E731 try: transformed_fetch = [_remap_element(fetch_type, fetch_name)] except KeyError: master_replica_name = ops.prepend_name_scope( fetch_name, replica_prefix(0)) master_replica_fetch = _remap_element(fetch_type, master_replica_name) polymorphic_dim = self._polymorphic_dim(master_replica_fetch) def is_train_op(op): # In TF2: train_op as AssignAddVariableOp # In TF1 (being deprecated): no_op with a groups of stateful ops as control dependencies # TODO(unless deprecating): make the checking as strict as possible return isinstance( op, ops.Operation) and (op.op_def.is_stateful or op.op_def.name == 'NoOp') if is_train_op(master_replica_fetch): transformed_fetch = [ _remap_element( fetch_type, ops.prepend_name_scope(fetch_name, replica_prefix(i))) for i in range(self._graph_transformer.num_local_replicas) ] #################################################################### # # For Debugging Local Replicas #################################################################### # transformed_fetch = [ # self._graph_item.graph.as_graph_element('AutoDist-Replica-0/emb/part_0_take_grad') # ] # transformed_fetch = [ # _remap_element(ops.Tensor, ops.prepend_name_scope( # 'Mean:0', # replica_prefix(i))) # for i in range(self._graph_transformer.num_local_replicas) # ] # transformed_fetch = [_remap_element(ops.Tensor, # ops.prepend_name_scope( # 
'sampled_softmax_loss/embedding_lookup:0', # replica_prefix(1) # ) # )] #################################################################### logging.debug('Fetch mapped from {} to {}'.format( fetch, transformed_fetch)) elif polymorphic_dim: transformed_fetch = [ _remap_element( fetch_type, ops.prepend_name_scope(fetch_name, replica_prefix(i))) for i in range(self._graph_transformer.num_local_replicas) ] contract_fn = lambda fetch_vals: np.concatenate( fetch_vals, axis=polymorphic_dim) # noqa: E731 else: transformed_fetch = [master_replica_fetch] return transformed_fetch, contract_fn
def partially_apply_saved_transform(saved_model_dir, input_tensors):
  """Apply a transform graph, represented as a SavedModel, to existing Tensors.

  This adds nodes to a graph that already contains Tensors representing the
  inputs.  These input Tensors may be placeholders that will be fed when the
  graph is executed, or may be the outputs of some Ops.  Most typically, the
  input Tensors are reading and/or parsing Ops, but they could be anything--
  including the outputs of a prior application of this function using another
  transform graph.

  This function operates on the default Graph in the default Session, and so
  must be called within a context where these are provided.

  Args:
    saved_model_dir: A SavedModel directory providing a transform
      graph.  The MetaGraphDef and signature are selected from the SavedModel
      using keys defined in `../constants.py` ('transform' and
      'transform_signature', respectively).
    input_tensors: a dict of logical name to Tensor.  The logical names must
      be a subset of those in the input signature of the transform graph, and
      the corresponding Tensors must have the expected types and shapes.

  Returns:
    A pair of (unbound_inputs, outputs) where unbound_inputs is a dict of
    logical name to Tensors that are yet to be mapped or fed, and outputs is
    a dict of logical name to Tensor, as provided by the output signature
    of the transform graph

  Raises:
    ValueError: if the provided input_tensors dict has keys that are not part
      of the input signature, or any of the provided inputs have the wrong
      type or shape.
    RuntimeError: if there is no default graph available to which to apply the
      transform.
  """
  decomposed_input_tensors = _decompose_sparse_tensors(input_tensors)

  meta_graph_def, input_signature, output_signature = (
      _load_transform_saved_model(saved_model_dir))

  # Check for inputs that were not part of the input signature.
  unexpected_inputs = set(decomposed_input_tensors) - set(input_signature)
  if unexpected_inputs:
    raise ValueError('Unexpected inputs '
                     'to transform: {}'.format(unexpected_inputs))

  # Create a map from tensor names in the graph to be imported, to the tensors
  # specified in `input_tensors`.
  input_map = {
      input_signature[decomposed_logical_name]:
      decomposed_input_tensors[decomposed_logical_name]
      for decomposed_logical_name in decomposed_input_tensors
  }

  graph = tf.get_default_graph()
  if graph is None:
    raise RuntimeError('apply_saved_transform() requires a default graph.')

  # unique_name may produce e.g. transform_5.  The result has no trailing slash.
  scope = graph.unique_name('transform', mark_as_used=False)

  # Load the transform graph, applying it to existing Tensors via input_map.
  # Throws ValueError if the input_map gives mismatched types or shapes.
  saver = tf_saver.import_meta_graph(meta_graph_def,
                                     import_scope=scope,
                                     input_map=input_map)
  if saver:
    tf.logging.warning(
        'Transform graphs should not have saved Variables, but this '
        'one does.  Variable values will *not* be restored.')

  # Add computed output tensors to the output.  There are two cases.  When the
  # output is not in the input_map, then we look up the tensor in the imported
  # graph by prepending the import scope and looking up the tensor by name.
  # This will fail if the expected output tensor is not now in the graph
  # under the expected name scope.  When the output is in the input map, then
  # that tensor will have been re-mapped so we use the tensor given in the
  # input_map.
  def lookup_remapped_tensor(tensor_name):
    if tensor_name in input_map:
      return input_map[tensor_name]
    else:
      return graph.get_tensor_by_name(
          ops.prepend_name_scope(tensor_name, scope))

  decomposed_output_tensors = {
      decomposed_logical_name: lookup_remapped_tensor(tensor_name)
      for decomposed_logical_name, tensor_name in output_signature.items()
  }
  # Do the same for input tensors, where we assume such tensors are not in the
  # input_map since identical tensors in an input_map would be an error.
  decomposed_unbound_input_tensors = {
      decomposed_logical_name:
      graph.get_tensor_by_name(ops.prepend_name_scope(tensor_name, scope))
      for decomposed_logical_name, tensor_name in input_signature.items()
      if decomposed_logical_name not in decomposed_input_tensors
  }

  outputs = _recompose_sparse_tensors(decomposed_output_tensors)
  unbound_inputs = _recompose_sparse_tensors(decomposed_unbound_input_tensors)
  return unbound_inputs, outputs
def lookup_remapped_tensor(tensor_name):
  """Resolves `tensor_name`, preferring an `input_map` substitution.

  Names present in `input_map` yield the mapped tensor; any other name is
  looked up in `graph` after prefixing it with the import `scope`.
  """
  if tensor_name not in input_map:
    scoped_name = ops.prepend_name_scope(tensor_name, scope)
    return graph.get_tensor_by_name(scoped_name)
  return input_map[tensor_name]
def _partially_apply_saved_transform_impl(
    saved_model_dir, logical_input_map, tensor_replacement_map=None,
    fetch_tensor_names=None):
  """Shared code for partially_apply_saved_transform and fetch_tensor_values.

  This adds nodes to a graph that already contains Tensors representing the
  inputs.  These input Tensors may be placeholders that will be fed when the
  graph is executed, or may be the outputs of some Ops.  Most typically, the
  input Tensors are reading and/or parsing Ops, but they could be anything--
  including the outputs of a prior application of this function using another
  transform graph.

  This function operates on the default Graph in the default Session, and so
  must be called within a context where these are provided.

  Args:
    saved_model_dir: A SavedModel directory providing a transform
      graph.  The MetaGraphDef and signature are selected from the SavedModel
      using keys defined in `../constants.py` ('transform' and
      'transform_signature', respectively).
    logical_input_map: a dict of logical name to Tensor.  The logical names must
      be a subset of those in the input signature of the transform graph, and
      the corresponding Tensors must have the expected types and shapes.
    tensor_replacement_map: a dict of tensor names to `Tensors`.
    fetch_tensor_names: a list of tensor names.

  Returns:
    A tuple of (unbound_inputs, outputs, fetched_tensors) where unbound_inputs
    is a dict of logical name to Tensors that are yet to be mapped or fed,
    outputs is a dict of logical name to Tensor, as provided by the output
    signature of the transform graph, and fetched_tensors is a dict of tensor
    names to `Tensor`s where the tensor names are the names given by
    `fetch_tensor_names`.

  Raises:
    ValueError: if the provided logical_input_map dict has keys that are not
      part of the input signature, if any of the provided inputs have the
      wrong type or shape, or if the SavedModel contains trainable variables.
    RuntimeError: if there is no default graph available to which to apply the
      transform.
  """
  graph = tf.get_default_graph()
  if graph is None:
    raise RuntimeError('apply_saved_transform() requires a default graph.')

  decomposed_input_tensors = _decompose_sparse_tensors(logical_input_map)

  meta_graph_def, input_signature, output_signature, asset_path_dict = (
      _load_transform_saved_model(saved_model_dir))
  asset_tensor_dict = {k: ops.convert_to_tensor(v)
                       for k, v in asset_path_dict.items()}

  # Check for inputs that were not part of the input signature.
  unexpected_inputs = (set(six.iterkeys(decomposed_input_tensors)) -
                       set(six.iterkeys(input_signature)))
  if unexpected_inputs:
    raise ValueError('Unexpected inputs '
                     'to transform: {}'.format(unexpected_inputs))

  # Create a map from tensor names in the graph to be imported, to the tensors
  # specified in `input_tensors`.
  input_map = {
      input_signature[decomposed_logical_name]:
      decomposed_input_tensors[decomposed_logical_name]
      for decomposed_logical_name in decomposed_input_tensors}
  input_map.update(asset_tensor_dict)
  if tensor_replacement_map:
    input_map.update(tensor_replacement_map)

  # unique_name may produce e.g. transform_5.  The result has no trailing slash.
  scope = graph.unique_name('transform', mark_as_used=False)

  # unique_name returns an "absolute" name while we want a name relative to the
  # current scope.  Therefore, we check if the current name stack is non-empty,
  # and if so, strip out the existing name scope.
  if graph.get_name_scope():
    current_name_scope = graph.get_name_scope() + '/'
    assert scope.startswith(current_name_scope)
    import_scope = scope[len(current_name_scope):]
  else:
    import_scope = scope

  # Save the ASSET_FILEPATHS before importing the MetaGraphDef
  current_assets = graph.get_collection(tf.GraphKeys.ASSET_FILEPATHS)

  # Reject SavedModels whose meta_graph_def has saved trainable variables.
  if tf.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
    trainable_vars = meta_graph_def.collection_def[
        tf.GraphKeys.TRAINABLE_VARIABLES].bytes_list.value
    if trainable_vars:
      raise ValueError(
          'The SavedModel contained trainable variables {}.  Because this '
          'function is typically called in the input_fn, trainable variables '
          'are disallowed'.format(trainable_vars))

  # Load the transform graph, applying it to existing Tensors via input_map.
  # Throws ValueError if the input_map gives mismatched types or shapes.
  saver = tf_saver.import_meta_graph(meta_graph_def,
                                     import_scope=import_scope,
                                     input_map=input_map)

  # Wipe out AssetFileDef collection; it is obsolete after loading
  graph.clear_collection(tf.saved_model.constants.ASSETS_KEY)

  # The import may have added Tensors to the ASSET_FILEPATHS collection that
  # were substituted via input_map.  To account for this, wipe out the
  # collection, restore the preexisting collection values, and then write in
  # the new substituted Tensors.
  graph.clear_collection(tf.GraphKeys.ASSET_FILEPATHS)
  for asset_path_tensor in current_assets:
    graph.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_path_tensor)
  for asset_path_tensor in asset_tensor_dict.values():
    graph.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_path_tensor)

  if saver:
    checkpoint_path = os.path.join(
        tf.compat.as_bytes(saved_model_dir),
        tf.compat.as_bytes(tf.saved_model.constants.VARIABLES_DIRECTORY),
        tf.compat.as_bytes(tf.saved_model.constants.VARIABLES_FILENAME))

    # We can't use the scope rename from init_from_checkpoint because it relies
    # on var scopes not rebuilt by import_meta_graph. So we need to construct it
    # explicitly by iterating over the variables.
    # Match on `scope + '/'`: a bare `startswith(scope)` would also pick up
    # variables from sibling scopes such as `transform_1` when scope is
    # `transform`, and would then slice their names incorrectly.
    scope_prefix = scope + '/'
    var_map = {}
    for var in tf.global_variables():
      if var.op.name.startswith(scope_prefix):
        var_map[var.op.name[len(scope_prefix):]] = var
    if var_map:
      tf.train.init_from_checkpoint(checkpoint_path, var_map)

  # Add computed output tensors to the output.  There are two cases.  When the
  # output is not in the input_map, then we look up the tensor in the imported
  # graph by prepending the import scope and looking up the tensor by name.
  # This will fail if the expected output tensor is not now in the graph
  # under the expected name scope.  When the output is in the input map, then
  # that tensor will have been re-mapped so we use the tensor given in the
  # input_map.
  def lookup_remapped_tensor(tensor_name):
    if tensor_name in input_map:
      return input_map[tensor_name]
    else:
      return graph.get_tensor_by_name(
          ops.prepend_name_scope(tensor_name, scope))

  decomposed_output_tensors = {
      decomposed_logical_name: lookup_remapped_tensor(tensor_name)
      for decomposed_logical_name, tensor_name
      in six.iteritems(output_signature)
  }
  # Do the same for input tensors, where we assume such tensors are not in the
  # input_map since identical tensors in an input_map would be an error.
  decomposed_unbound_input_tensors = {
      decomposed_logical_name: graph.get_tensor_by_name(
          ops.prepend_name_scope(tensor_name, scope))
      for decomposed_logical_name, tensor_name
      in six.iteritems(input_signature)
      if decomposed_logical_name not in decomposed_input_tensors
  }
  if fetch_tensor_names is None:
    fetch_tensor_names = []
  fetched_tensors = {
      name: lookup_remapped_tensor(name) for name in fetch_tensor_names}

  outputs = _recompose_sparse_tensors(decomposed_output_tensors)
  unbound_inputs = _recompose_sparse_tensors(decomposed_unbound_input_tensors)
  return unbound_inputs, outputs, fetched_tensors
def _partially_apply_saved_transform_impl(saved_model_dir,
                                          logical_input_map,
                                          tensor_replacement_map=None,
                                          fetch_tensor_names=None):
  """Imports a saved transform graph into the default graph.

  Common implementation behind partially_apply_saved_transform and
  fetch_tensor_values. The transform graph is imported under a fresh
  'transform' name scope, and its input-signature tensors are wired to
  tensors that already exist in the default graph. Those tensors can be
  anything: placeholders fed at run time, parsing ops, or the outputs of an
  earlier application of another transform graph.

  Any saved Variables in the transform graph are NOT restored; a warning is
  logged instead.

  Args:
    saved_model_dir: SavedModel directory holding the transform graph. The
      MetaGraphDef and signature are chosen by `_load_transform_saved_model`.
    logical_input_map: dict of logical input name -> Tensor (possibly
      sparse). Every name must appear in the transform graph's input
      signature.
    tensor_replacement_map: optional dict of tensor name (as named inside the
      transform graph) -> `Tensor`. It is merged into the import input_map
      last, so its entries override both asset and logical-input mappings.
    fetch_tensor_names: optional list of tensor names (as named inside the
      transform graph) to resolve after the import.

  Returns:
    A tuple `(unbound_inputs, outputs, fetched_tensors)`:
      unbound_inputs: dict of logical name -> Tensor for signature inputs not
        supplied in `logical_input_map`; these still need to be fed or
        mapped.
      outputs: dict of logical name -> Tensor, following the transform
        graph's output signature.
      fetched_tensors: dict keyed by the names in `fetch_tensor_names`,
        mapping each to its `Tensor` in the default graph.

  Raises:
    ValueError: if `logical_input_map` contains names that are not in the
      input signature, or (from the import) if a supplied tensor has the
      wrong type or shape.
    RuntimeError: if `tf.get_default_graph()` returns None.
  """
  graph = tf.get_default_graph()
  if graph is None:
    raise RuntimeError('apply_saved_transform() requires a default graph.')

  decomposed_input_tensors = _decompose_sparse_tensors(logical_input_map)

  loaded = _load_transform_saved_model(saved_model_dir)
  meta_graph_def, input_signature, output_signature, asset_path_dict = loaded

  asset_tensor_dict = {}
  for asset_key, asset_path in asset_path_dict.items():
    asset_tensor_dict[asset_key] = ops.convert_to_tensor(asset_path)

  # Refuse logical inputs the transform graph does not declare.
  extra_names = (set(six.iterkeys(decomposed_input_tensors)) -
                 set(six.iterkeys(input_signature)))
  if extra_names:
    raise ValueError(
        'Unexpected inputs to transform: {}'.format(extra_names))

  # Signature tensor name -> caller tensor that replaces it on import.
  input_map = {}
  for logical_name in decomposed_input_tensors:
    signature_tensor_name = input_signature[logical_name]
    input_map[signature_tensor_name] = decomposed_input_tensors[logical_name]
  input_map.update(asset_tensor_dict)
  if tensor_replacement_map:
    input_map.update(tensor_replacement_map)

  # Peek (without reserving) at the next free 'transform' scope, e.g.
  # transform_5. The returned name has no trailing slash.
  scope = graph.unique_name('transform', mark_as_used=False)

  # Snapshot ASSET_FILEPATHS now, because the import below may add to it.
  preexisting_assets = graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)

  # Import the graph, wiring its inputs through input_map. A type or shape
  # mismatch in input_map makes the import raise ValueError.
  saver = tf_saver.import_meta_graph(
      meta_graph_def, import_scope=scope, input_map=input_map)

  for op in graph.get_operations():
    if op.type != b'Where':
      continue
    # pylint: disable=protected-access
    if 'T' in op._node_def.attr:
      del op._node_def.attr['T']
    # pylint: enable=protected-access

  # The AssetFileDef collection has no further use once the import is done.
  graph.clear_collection(tf.saved_model.constants.ASSETS_KEY)

  # The import may have added ASSET_FILEPATHS entries for tensors that
  # input_map substituted away. Rebuild the collection from the snapshot,
  # then append the substitute asset tensors.
  graph.clear_collection(ops.GraphKeys.ASSET_FILEPATHS)
  rebuilt_assets = list(preexisting_assets) + list(asset_tensor_dict.values())
  for asset_tensor in rebuilt_assets:
    graph.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset_tensor)

  if saver:
    tf.logging.warn(
        'Transform graphs should not have saved Variables, but this '
        'one does. Variable values will *not* be restored.')

  def _resolve(tensor_name):
    """Returns the default-graph Tensor for a transform-graph tensor name.

    A name that was remapped on import resolves to its input_map entry.
    Any other name is looked up under the import scope; that lookup fails if
    the tensor is missing from the scope.
    """
    if tensor_name not in input_map:
      return graph.get_tensor_by_name(
          ops.prepend_name_scope(tensor_name, scope))
    return input_map[tensor_name]

  decomposed_output_tensors = {}
  for logical_name, tensor_name in six.iteritems(output_signature):
    decomposed_output_tensors[logical_name] = _resolve(tensor_name)

  # Unsupplied signature inputs are looked up directly under the scope. They
  # cannot be in input_map, because only supplied inputs are mapped there.
  decomposed_unbound_input_tensors = {}
  for logical_name, tensor_name in six.iteritems(input_signature):
    if logical_name in decomposed_input_tensors:
      continue
    decomposed_unbound_input_tensors[logical_name] = graph.get_tensor_by_name(
        ops.prepend_name_scope(tensor_name, scope))

  names_to_fetch = [] if fetch_tensor_names is None else fetch_tensor_names
  fetched_tensors = {}
  for name in names_to_fetch:
    fetched_tensors[name] = _resolve(name)

  outputs = _recompose_sparse_tensors(decomposed_output_tensors)
  unbound_inputs = _recompose_sparse_tensors(decomposed_unbound_input_tensors)
  return unbound_inputs, outputs, fetched_tensors
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into a
  `MetaGraphDef` protocol buffer. The result can be imported later, at
  another time or location, to restart training, run inference, or serve as
  a subgraph.

  Args:
    filename: Optional filename, including the path, for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract the
      subgraph. The scope name is stripped from the node definitions so the
      graph can later be imported into new name scopes. If `None`, the whole
      graph is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string
      collection with this name is added to the returned `MetaGraphDef`. It
      holds the names of tensors that must be remapped when the
      `MetaGraphDef` is imported.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that is not associated
      with the provided SaverDef.
    strip_default_attrs: Set to true if default-valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` and `filename` is given, also write the
      GraphDebugInfo for the exported ops to a separate file. That file is
      placed in the same directory as `filename`, and its name is `filename`
      with the extension replaced by `.debug`. Ignored when `filename` is
      not given.
    **kwargs: Optional keyed arguments, including meta_info_def and
      collection_list.

  Returns:
    A `MetaGraphDef` proto and a dictionary of `Variables` in the exported
    name scope, keyed by variable name with `export_scope` stripped.

  Raises:
    ValueError: When eager execution is enabled, or when the `GraphDef` is
      larger than 2GB.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      # Filter/rewrite the caller-supplied GraphDef node by node.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # Only do this complicated work if we want to remove a name scope.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          if value.outputs:
            # Record the static output shapes of the op on its exported node.
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          bytesize += value.node_def.ByteSize()
          # Protobuf messages are limited to 2GB; a negative total would
          # indicate an overflowed size.
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      # The debug file shares the stem of `filename`, with its extension
      # replaced by ".debug".
      name, _ = os.path.splitext(filename)
      debug_filename = "{name}{ext}".format(name=name, ext=".debug")

      # Look up each exported op by its original (scoped) name. This must be
      # a set: a dict literal `{}` has no `.add` method.
      ops_to_export = set()
      for node in scoped_meta_graph_def.graph_def.node:
        scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
        ops_to_export.add(graph.get_operation_by_name(scoped_op_name))

      graph_debug_info = create_graph_debug_info_def(ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list