def _restore_checkpoint(self):
  """Restores checkpointed values into the objects this loader rebuilt.

  Builds a `TrackableSaver` over the object graph rooted at `self.get(0)` and
  restores it from the SavedModel's variables path.

  - If `self._save_options.allow_partial_checkpoint` is set, a partial match
    is accepted, but it must still be non-trivial.
  - Otherwise every existing object must be matched.

  In graph mode the restore ops are not executed here. Instead they are wired
  in as follows:
  - resource variables get them as their `_initializer_op`;
  - lookup tables get them added to the `TABLE_INITIALIZERS` collection.

  Raises:
    NotImplementedError: in graph mode, if a checkpointed object uses a
      registered checkpoint saver, or if an object has restore ops but is
      neither a resource variable nor a `LookupInterface`.
  """
  # TODO(b/205010730): Clean use of private methods of TrackableSaver.
  # pylint: disable=protected-access
  prefix = saved_model_utils.get_variables_path(self._export_dir)
  object_graph = graph_view.ObjectGraphView(self.get(0))
  trackable_saver = checkpoint.TrackableSaver(object_graph)

  # The file-prefix placeholder constant is created on the CPU.
  with ops.device("CPU"):
    trackable_saver._file_prefix_placeholder = constant_op.constant(prefix)

  partial_ok = self._save_options.allow_partial_checkpoint
  if partial_ok:
    status = trackable_saver.restore(
        prefix, self._checkpoint_options).expect_partial()
    status.assert_nontrivial_match()
  else:
    status = trackable_saver.restore(prefix, self._checkpoint_options)
    status.assert_existing_objects_matched()
  restored_ckpt = status._checkpoint

  # Eagerly, `restore` above has already run and set trackable state, so
  # calling `position.restore_ops()` would restore a second time. Only graph
  # mode needs the (cached) restore ops wired into initializers. That way,
  # common initialization flows (e.g. ManagedSession) restore values without
  # extra user action.
  if context.executing_eagerly():
    return

  # Iterate over a snapshot of the mapping. Presumably restoring positions
  # can mutate it — TODO confirm.
  for proto_id, trackable in dict(restored_ckpt.object_by_proto_id).items():
    position = base.CheckpointPosition(
        checkpoint=restored_ckpt, proto_id=proto_id)
    saver_name = position.get_registered_saver_name()
    if saver_name:
      raise NotImplementedError(
          "Loading a SavedModel that uses registered checkpoint saver is "
          f"not supported in graph mode. The loaded object {trackable} uses "
          f"the saver registered with the name {saver_name}.")

    pending_ops = position.restore_ops()
    if not pending_ops:
      continue

    if resource_variable_ops.is_resource_variable(trackable):
      trackable._initializer_op = (
          pending_ops[0] if len(pending_ops) == 1
          else control_flow_ops.group(*pending_ops))
    elif isinstance(trackable, lookup_ops.LookupInterface):
      # Eager mode returned above, so we are building a graph here and a
      # graph collection is the right place for the table initializers.
      ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, pending_ops)
    else:
      raise NotImplementedError(
          f"Unable to restore state of object {trackable} from the checkpoint."
      )
def test_serialize_gathered_objects(self):
  """Variables become saveables; registered-saver objects are grouped apart."""
  tree = autotrackable.AutoTrackable()
  tree.v = variables.Variable(1.0)
  tree.registered = TrackableWithRegisteredSaver()
  view = graph_view.ObjectGraphView(tree)

  saveables, _graph_proto, _feed_additions, by_saver = (
      save_util_v1.serialize_gathered_objects(view))

  # Only the plain variable is emitted as a saveable object.
  self.assertLen(saveables, 1)
  self.assertIs(saveables[0].op, tree.v)

  # The registered-saver object is reported under its saver name instead.
  expected = {"Custom.RegisteredSaver": {"registered": tree.registered}}
  self.assertDictEqual(expected, by_saver)
def __init__(self, mesh: layout.Mesh, root=None, **kwargs):
  """Creates a checkpoint whose saver is a mesh-bound `DTrackableSaver`.

  Args:
    mesh: The `layout.Mesh` passed to the `DTrackableSaver`.
    root: Optional trackable to checkpoint through. When it is truthy, it
      becomes the saver's root and every keyword argument (plus `root`
      itself, under the name "root") is attached as its dependency.
    **kwargs: Objects to track, keyed by dependency name.

  Raises:
    ValueError: if `root` already has a dependency under one of the keyword
      names that compares unequal to the (converted) keyword value.
  """
  super(DTensorCheckpoint, self).__init__(root=root, **kwargs)
  self._mesh = mesh

  graph_root = self
  extra_deps = None
  self._save_counter = None  # Created lazily for restore-on-create.
  self._save_assign_op = None

  if root:
    util._assert_trackable(root, "root")
    graph_root = root
    extra_deps = []
    # Every keyword argument, including root itself, is made a child of root.
    kwargs["root"] = root
    root._maybe_initialize_trackable()
    self._save_counter = data_structures.NoDependency(
        root._lookup_dependency("save_counter"))
    self._root = data_structures.NoDependency(root)

  for name, value in sorted(kwargs.items(), key=lambda pair: pair[0]):
    setattr(self, name, value)
    # Read the attribute back rather than using `value` directly: setattr
    # converts list/dict/tuple values into Trackable data structures.
    tracked = getattr(self, name)
    util._assert_trackable(tracked, name)

    if not root:
      continue
    # Root must not already hold a different dependency under this name.
    if extra_deps is None:
      extra_deps = []
    existing = root._lookup_dependency(name)
    if existing is None:
      extra_deps.append(base.TrackableReference(name, tracked))
    elif existing != tracked:
      raise ValueError(
          "Cannot create a Checkpoint with keyword argument {name} if "
          "root.{name} already exists.".format(name=name))

  # DTensor change: replace the saver created by the parent class with a
  # DTrackableSaver bound to `mesh`. The graph root is held by weak reference.
  self._saver = DTrackableSaver(
      mesh,
      graph_view_lib.ObjectGraphView(
          weakref.ref(graph_root), attached_dependencies=extra_deps))
def test_serialize_gathered_objects_with_map(self):
  """An object map makes serialization report the mapped copies."""
  tree = autotrackable.AutoTrackable()
  tree.v = variables.Variable(1.0)
  tree.registered = TrackableWithRegisteredSaver()

  registered_copy = TrackableWithRegisteredSaver()
  v_copy = variables.Variable(1.0)

  substitutions = object_identity.ObjectIdentityDictionary()
  substitutions[tree.registered] = registered_copy
  substitutions[tree.v] = v_copy

  saveables, _graph_proto, _feed_additions, by_saver = (
      save_util_v1.serialize_gathered_objects(
          graph_view.ObjectGraphView(tree), substitutions))

  # The single saveable refers to the copy, not the original variable.
  self.assertLen(saveables, 1)
  saved_op = saveables[0].op
  self.assertIsNot(saved_op, tree.v)
  self.assertIs(saved_op, v_copy)

  # The registered-saver entry likewise refers to the mapped copy.
  mapped_registered = by_saver["Custom.RegisteredSaver"]["registered"]
  self.assertIsNot(tree.registered, mapped_registered)
  self.assertIs(registered_copy, mapped_registered)
def _get_var_list(model):
  """Returns the named saveable objects from `model`'s object graph.

  Only the first element of `serialize_object_graph()`'s 3-tuple is returned.
  The serialized graph proto and the feed additions are discarded.
  """
  object_graph = graph_view.ObjectGraphView(model)
  saveables, _unused_proto, _unused_feeds = (
      object_graph.serialize_object_graph())
  return saveables