def restore(self, save_path):
    """Load a training checkpoint into the tracked object graph.

    Values are restored into `root_checkpointable` and everything reachable
    from it. Variables that already exist are assigned right away; for those
    not yet created, restoration is deferred until creation. Dependencies
    attached to the root after this call are still matched, provided the
    checkpoint holds a corresponding object.

    While graph building, restore ops are only added to the graph; nothing is
    run. To forbid deferred loading, assert that everything was matched:

    ```python
    saver = Saver(root)
    saver.restore(path).assert_consumed()
    ```

    That raises unless every object was matched and its variables already
    exist. When graph building, a passing `assert_consumed()` means every
    restore op this checkpoint will ever need has been created; run them with
    the status object's `run_restore_ops()`:

    ```python
    saver.restore(path).assert_consumed().run_restore_ops()
    ```

    Until the checkpoint is fully consumed, the set of restore ops keeps
    growing as more objects join the dependency graph.

    Checkpoints written by the name-based `tf.train.Saver` are also accepted.
    For those there is no deferred loading, variables are matched by name, and
    no restore ops are created or run until `run_restore_ops()` or
    `initialize_or_restore()` is called on the returned status, even when
    executing eagerly. Re-save such checkpoints with the object-based
    `Saver.save` as soon as possible.

    Args:
      save_path: Checkpoint path as returned by `save` or
        `tf.train.latest_checkpoint`. `None` (no latest checkpoint) yields a
        status object that may run initializers for objects in the dependency
        graph.

    Returns:
      A load status object: `InitializationOnlyStatus` when `save_path` is
      `None`, `NameBasedSaverStatus` when the checkpoint has no object graph
      entry, and `CheckpointLoadStatus` otherwise.

    Raises:
      NotImplementedError: When graph building and this saver has already
        restored a checkpoint with a different object graph.
    """
    if save_path is None:
        return InitializationOnlyStatus(self._root_checkpointable, ops.uid())

    graph_building = not context.executing_eagerly()
    if graph_building:
        # One feedable prefix tensor is created lazily and then shared by every
        # restore issued through this saver; the real path arrives via feed.
        if self._file_prefix_placeholder is None:
            with ops.device("/cpu:0"):
                self._file_prefix_placeholder = constant_op.constant("model")
        prefix_tensor = self._file_prefix_placeholder
        prefix_feed = {prefix_tensor: save_path}
    else:
        with ops.device("/cpu:0"):
            prefix_tensor = constant_op.constant(save_path)
        prefix_feed = None

    ckpt_reader = pywrap_tensorflow.NewCheckpointReader(save_path)
    try:
        serialized_graph = ckpt_reader.get_tensor(
            checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError:
        # No object graph stored in this checkpoint, so fall back to
        # name-based restoration.
        return NameBasedSaverStatus(self, save_path)

    graph_proto = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
    graph_proto.ParseFromString(serialized_graph)

    reuse_previous = (
        graph_building and graph_proto == self._last_restore_object_graph)
    if reuse_previous:
        coordinator = self._last_restore_checkpoint
    else:
        dtype_map = (
            None if graph_building
            else ckpt_reader.get_variable_to_dtype_map())
        coordinator = _CheckpointRestoreCoordinator(
            object_graph_proto=graph_proto,
            save_path=prefix_tensor,
            dtype_map=dtype_map)
        if graph_building and self._last_restore_object_graph is not None:
            raise NotImplementedError(
                "Using a single Saver to restore different object graphs is not "
                "currently supported when graph building. Use a different Saver "
                "for each object graph (restore ops will be duplicated), or "
                "file a feature request if this limitation bothers you.")
        if graph_building:
            # Remember the coordinator so repeat restores reuse its ops.
            self._last_restore_checkpoint = coordinator
            self._last_restore_object_graph = graph_proto

    root_position = checkpointable_lib._CheckpointPosition(  # pylint: disable=protected-access
        checkpoint=coordinator, proto_id=0)
    root_position.restore(self._root_checkpointable)
    return CheckpointLoadStatus(
        coordinator,
        root_checkpointable=self._root_checkpointable,
        feed_dict=prefix_feed)
def restore(save_path, root_checkpointable, session=None):
    """Restore a training checkpoint.

    Restores the values of variables created with
    `Checkpointable._add_variable` in `root_checkpointable` and any objects
    that it tracks (transitive). Either assigns values immediately if
    variables to restore have been created already, or defers restoration
    until the variables are created. Dependencies added to
    `root_checkpointable` after this call will be matched if they have a
    corresponding object in the checkpoint.

    When building a graph, restorations are added to the graph but not run. A
    session is required to retrieve checkpoint metadata.

    To disallow deferred loading, assert immediately that all checkpointed
    variables have been matched to variable objects:

    ```python
    restore(path, root).assert_consumed()
    ```

    An exception will be raised unless every object was matched and its
    variables already exist.

    When graph building, `assert_consumed()` indicates that all of the restore
    ops which will be created for this checkpoint have been created. They are
    available in the `restore_ops` property of the status object:

    ```python
    session.run(restore(path, root).assert_consumed().restore_ops)
    ```

    If the checkpoint has not been consumed completely, then the list of
    `restore_ops` will grow as more objects are added to the dependency graph.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), does nothing.
      root_checkpointable: The root of the object graph to restore. Variables
        to restore need not have been created yet, but all dependencies on
        other `Checkpointable` objects should already be declared. Objects in
        the dependency graph are matched to objects in the checkpointed graph,
        and matching objects have their variables restored (or the
        checkpointed values saved for eventual restoration when the variable
        is created).
      session: The session to retrieve metadata with. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      A `CheckpointLoadStatus` object, which can be used to make assertions
      about the status of checkpoint restoration and fetch restore ops, or
      `None` when `save_path` is `None`.

    Raises:
      ValueError: If graph building and neither `session` nor a default
        session is available to read the checkpoint metadata.
    """
    if save_path is None:
        return None
    if context.in_graph_mode():
        if session is None:
            session = ops.get_default_session()
        if session is None:
            # Without this guard the code below would fall through to
            # `.numpy()` on a symbolic tensor and fail with an opaque
            # AttributeError.
            raise ValueError(
                "restore() needs a session to read checkpoint metadata when "
                "graph building; pass `session` or call it inside a "
                "`with session.as_default():` block.")
    else:
        # Eager execution evaluates ops directly; any passed session is ignored.
        session = None
    # The serialized object graph is stored in the checkpoint as a string
    # tensor under a reserved key.
    object_graph_string, = io_ops.restore_v2(
        prefix=save_path,
        tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
        shape_and_slices=[""],
        dtypes=[dtypes.string],
        name="object_graph_proto_read")
    if session is not None:
        object_graph_string = session.run(object_graph_string)
    else:
        object_graph_string = object_graph_string.numpy()
    object_graph_proto = (
        checkpointable_object_graph_pb2.CheckpointableObjectGraph())
    object_graph_proto.ParseFromString(object_graph_string)
    checkpoint = core_checkpointable._Checkpoint(  # pylint: disable=protected-access
        object_graph_proto=object_graph_proto,
        save_path=save_path)
    # proto_id=0 is the root node of the saved object graph.
    core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
        checkpoint=checkpoint, proto_id=0).restore(root_checkpointable)
    load_status = CheckpointLoadStatus(checkpoint)
    return load_status
def restore(self, save_path, session=None):
    """Restore a training checkpoint.

    Restores `root_checkpointable` and any objects that it tracks
    (transitive). Either assigns values immediately if variables to restore
    have been created already, or defers restoration until the variables are
    created. Dependencies added to the `root_checkpointable` passed to the
    constructor after this call will be matched if they have a corresponding
    object in the checkpoint.

    When building a graph, restorations are added to the graph but not run. A
    session is required to retrieve checkpoint metadata.

    To disallow deferred loading, assert immediately that all checkpointed
    variables have been matched to variable objects:

    ```python
    saver = Saver(root)
    saver.restore(path).assert_consumed()
    ```

    An exception will be raised unless every object was matched and its
    variables already exist.

    When graph building, `assert_consumed()` indicates that all of the restore
    ops which will be created for this checkpoint have been created. They can
    be run via the `run_restore_ops()` function of the status object:

    ```python
    saver.restore(path).assert_consumed().run_restore_ops()
    ```

    If the checkpoint has not been consumed completely, then the list of
    restore ops will grow as more objects are added to the dependency graph.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), does nothing.
      session: The session to retrieve metadata with. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      A `CheckpointLoadStatus` object, which can be used to make assertions
      about the status of checkpoint restoration and run restore ops, or
      `None` when `save_path` is `None`.

    Raises:
      ValueError: If graph building and neither `session` nor a default
        session is available to read the checkpoint metadata.
      NotImplementedError: If graph building and this saver was already used
        to restore a checkpoint with a different object graph.
    """
    if save_path is None:
        return None
    in_graph_mode = context.in_graph_mode()
    if in_graph_mode:
        if session is None:
            session = ops.get_default_session()
        if session is None:
            # Fail before any op is created or cached on `self`; otherwise
            # `session.run` below would raise an opaque AttributeError.
            raise ValueError(
                "restore() needs a session to read checkpoint metadata when "
                "graph building; pass `session` or call it inside a "
                "`with session.as_default():` block.")
        # The shared prefix tensor is fed the concrete path at run time.
        file_prefix_tensor = self._file_prefix_placeholder
        file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
    else:
        # Eager execution evaluates ops directly; any passed session is ignored.
        session = None
        file_prefix_tensor = constant_op.constant(save_path)
        file_prefix_feed_dict = None
    if not in_graph_mode or self._object_graph_restore_tensor is None:
        object_graph_string, = io_ops.restore_v2(
            prefix=file_prefix_tensor,
            tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
            shape_and_slices=[""],
            dtypes=[dtypes.string],
            name="object_graph_proto_read")
        if in_graph_mode:
            # Cache the read op so repeated restores do not grow the graph.
            self._object_graph_restore_tensor = object_graph_string
    if in_graph_mode:
        object_graph_string = session.run(
            self._object_graph_restore_tensor,
            feed_dict=file_prefix_feed_dict)
    else:
        object_graph_string = object_graph_string.numpy()
    object_graph_proto = (
        checkpointable_object_graph_pb2.CheckpointableObjectGraph())
    object_graph_proto.ParseFromString(object_graph_string)
    if in_graph_mode and object_graph_proto == self._last_restore_object_graph:
        # Same object graph as the previous restore: reuse its checkpoint
        # object (and therefore its restore ops).
        checkpoint = self._last_restore_checkpoint
    else:
        if in_graph_mode:
            dtype_map = None
        else:
            reader = pywrap_tensorflow.NewCheckpointReader(save_path)
            dtype_map = reader.get_variable_to_dtype_map()
        checkpoint = core_checkpointable_utils._Checkpoint(  # pylint: disable=protected-access
            object_graph_proto=object_graph_proto,
            save_path=file_prefix_tensor,
            dtype_map=dtype_map)
        if in_graph_mode:
            if self._last_restore_object_graph is not None:
                raise NotImplementedError(
                    "Using a single Saver to restore different object graphs is not "
                    "currently supported when graph building. Use a different Saver "
                    "for each object graph (restore ops will be duplicated), or "
                    "file a feature request if this limitation bothers you.")
            self._last_restore_checkpoint = checkpoint
            self._last_restore_object_graph = object_graph_proto
    # proto_id=0 is the root node of the saved object graph.
    core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
        checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable)
    load_status = CheckpointLoadStatus(
        checkpoint, feed_dict=file_prefix_feed_dict)
    return load_status
def restore(save_path, root_checkpointable, session=None):
    """Restore a training checkpoint.

    Restores the values of variables created with
    `Checkpointable._add_variable` in `root_checkpointable` and any objects
    that it tracks (transitive). Either assigns values immediately if
    variables to restore have been created already, or defers restoration
    until the variables are created. Dependencies added to
    `root_checkpointable` after this call will be matched if they have a
    corresponding object in the checkpoint.

    When building a graph, restorations are executed in the default session if
    `session` is `None`. Variable initializers read checkpointed values.

    To disallow deferred loading, assert immediately that all checkpointed
    variables have been matched to variable objects:

    ```python
    restore(path, root).assert_consumed()
    ```

    An exception will be raised unless every object was matched and its
    variables already exist.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), does nothing.
      root_checkpointable: The root of the object graph to restore. Variables
        to restore need not have been created yet, but all dependencies on
        other Checkpointable objects should already be declared. Objects in
        the dependency graph are matched to objects in the checkpointed graph,
        and matching objects have their variables restored (or the
        checkpointed values saved for eventual restoration when the variable
        is created).
      session: The session to evaluate assignment ops in. Ignored when
        executing eagerly. If not provided when graph building, the default
        session is used.

    Returns:
      A CheckpointLoadStatus object, which can be used to make assertions
      about the status of checkpoint restoration, or `None` when `save_path`
      is `None`.

    Raises:
      ValueError: If graph building and neither `session` nor a default
        session is available.
    """
    if save_path is None:
        return None
    if context.in_graph_mode():
        if session is None:
            session = ops.get_default_session()
        if session is None:
            # Without this guard the metadata read below would call `.numpy()`
            # on a symbolic tensor and `None` would be handed to `_Checkpoint`
            # as its session.
            raise ValueError(
                "restore() needs a session when graph building; pass "
                "`session` or call it inside a `with session.as_default():` "
                "block.")
    else:
        # Eager execution evaluates ops directly; any passed session is ignored.
        session = None
    # The serialized object graph is stored in the checkpoint as a string
    # tensor under a reserved key.
    object_graph_string, = io_ops.restore_v2(
        prefix=save_path,
        tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
        shape_and_slices=[""],
        dtypes=[dtypes.string],
        name="object_graph_proto_read")
    if session is not None:
        object_graph_string = session.run(object_graph_string)
    else:
        object_graph_string = object_graph_string.numpy()
    object_graph_proto = (
        checkpointable_object_graph_pb2.CheckpointableObjectGraph())
    object_graph_proto.ParseFromString(object_graph_string)
    checkpoint = core_checkpointable._Checkpoint(  # pylint: disable=protected-access
        object_graph_proto=object_graph_proto,
        save_path=save_path,
        session=session)
    # proto_id=0 is the root node of the saved object graph.
    core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
        checkpoint=checkpoint, proto_id=0).restore(root_checkpointable)
    return CheckpointLoadStatus(checkpoint)