def __init__(self, mesh: layout.Mesh, root=None, **kwargs):
    """Creates a DTensor-aware checkpoint.

    Args:
      mesh: A `layout.Mesh`. Stored on `self._mesh` and passed to the
        `DTrackableSaver` built at the end of this constructor.
      root: Optional trackable. When truthy, it is validated with
        `util._assert_trackable` and becomes the object the saver's graph view
        is rooted at (otherwise the graph view is rooted at this checkpoint).
        Each keyword argument (and `root` itself, under the name "root") that
        `root` does not already depend on is recorded as an attached dependency
        of that graph view. These are not added to `root` itself.
      **kwargs: Objects to track. Each is set as an attribute on this
        checkpoint, in sorted key order.

    Raises:
      ValueError: If `root` is given and already has a dependency with the
        same name as a keyword argument, but that dependency differs from the
        value passed.
      NOTE(review): `util._assert_trackable` presumably raises for
        non-trackable `root` / keyword values. Confirm the exception type there.
    """
    super(DTensorCheckpoint, self).__init__(root=root, **kwargs)
    self._mesh = mesh

    # By default the saver walks the graph starting from this checkpoint
    # object, with no extra (unattached) dependencies.
    saver_root = self
    attached_dependencies = None
    self._save_counter = None  # Created lazily for restore-on-create.
    self._save_assign_op = None

    # NOTE(review): this uses truthiness, not `is not None`. A falsy trackable
    # passed as `root` would be treated as "no root". Confirm that is intended.
    if root:
        util._assert_trackable(root, "root")
        saver_root = root
        attached_dependencies = []

        # All keyword arguments (including root itself) are set as children
        # of root.
        kwargs["root"] = root
        root._maybe_initialize_trackable()

        # Reuse root's existing "save_counter" dependency (None if absent),
        # wrapped in NoDependency so it is not tracked as a child of self.
        self._save_counter = data_structures.NoDependency(
            root._lookup_dependency("save_counter"))
        self._root = data_structures.NoDependency(root)

    for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
        setattr(self, k, v)

        # Call getattr instead of directly using v because setattr converts
        # v to a Trackable data structure when v is a list/dict/tuple.
        converted_v = getattr(self, k)
        util._assert_trackable(converted_v, k)

        if root:
            # Make sure that root doesn't already have dependencies with these
            # names: new names are recorded as attached dependencies, while a
            # different object under an existing name is an error.
            # (attached_dependencies is already a list whenever root is truthy;
            # the `or []` is only a defensive default.)
            attached_dependencies = attached_dependencies or []
            child = root._lookup_dependency(k)
            if child is None:
                attached_dependencies.append(
                    base.TrackableReference(k, converted_v))
            elif child != converted_v:
                raise ValueError(
                    "Cannot create a Checkpoint with keyword argument {name} if "
                    "root.{name} already exists.".format(name=k))
    # DTensor change: replace the parent class's saver with a DTrackableSaver
    # bound to `mesh`. The original note also mentions _SingleDeviceSaver,
    # presumably used inside DTrackableSaver (confirm). The graph view holds
    # only a weak reference to saver_root.
    self._saver = DTrackableSaver(
        mesh,
        graph_view_lib.ObjectGraphView(
            weakref.ref(saver_root),
            attached_dependencies=attached_dependencies))
def _descendants_with_paths(self):
    """Walks the trackable graph breadth-first, starting at ``self.root``.

    Children of a node are obtained from ``self.children(node)``, a mapping of
    edge name to child. A node is recorded only the first time it is reached,
    so its recorded path is a shortest one (by edge count) from the root.

    Returns:
      A tuple ``(bfs_sorted, node_paths)``. ``bfs_sorted`` is the list of
      reachable nodes in visitation order, root first. ``node_paths`` maps each
      node (by identity) to the tuple of ``base.TrackableReference`` edges
      leading from the root to it. The root maps to ``()``.
    """
    start = self.root
    paths = object_identity.ObjectIdentityDictionary()
    paths[start] = ()
    order = []
    frontier = collections.deque([start])
    while frontier:
        node = frontier.popleft()
        order.append(node)
        # A node's path is fixed once recorded, so look it up once per node.
        prefix = paths[node]
        for child_name, child in self.children(node).items():
            if child in paths:
                continue  # Already reached by a path at least as short.
            paths[child] = prefix + (base.TrackableReference(child_name, child),)
            frontier.append(child)
    return order, paths
def _breadth_first_traversal(self):
    """Computes a shortest path from ``self.root`` to every reachable object.

    Traversal is breadth-first. Each node's children come from
    ``self.list_children(node)``, which yields ``(name, child)`` pairs. The
    visitation list doubles as the work queue: nodes are processed in the
    order they were discovered, which is exactly BFS order.

    Returns:
      A tuple ``(bfs_sorted, node_paths)``. ``bfs_sorted`` lists reachable
      nodes in BFS order, root first. ``node_paths`` maps each node (by
      identity) to a tuple of ``base.TrackableReference`` edges from the root.
      The root maps to ``()``.
    """
    root = self.root
    shortest_paths = object_identity.ObjectIdentityDictionary()
    shortest_paths[root] = ()
    ordered = [root]
    cursor = 0
    while cursor < len(ordered):
        parent = ordered[cursor]
        cursor += 1
        for edge_name, child in self.list_children(parent):
            if child not in shortest_paths:
                shortest_paths[child] = shortest_paths[parent] + (
                    base.TrackableReference(edge_name, child),)
                ordered.append(child)
    return ordered, shortest_paths
def list_children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns the child trackables of ``obj`` as ``TrackableReference``s.

    Args:
      obj: A `Trackable` object.
      save_type: A string, can be 'savedmodel' or 'checkpoint'.
      **kwargs: Forwarded unchanged when retrieving the object's children.

    Returns:
      A list of ``base.TrackableReference``. It first holds one entry per item
      of the name->child mapping returned by the ``children`` implementation
      that follows ``ObjectGraphView`` in the MRO. When ``obj`` is this view's
      root, ``self._attached_dependencies`` (if non-empty) is appended.
    """
    inherited = super(ObjectGraphView, self).children(obj, save_type, **kwargs)
    references = [
        base.TrackableReference(name, ref) for name, ref in inherited.items()
    ]
    if obj is not self.root:
        return references
    # The graph view may define extra children of the root that are not
    # actually attached to it, e.g. a Checkpoint object's save_counter.
    if self._attached_dependencies:
        references.extend(self._attached_dependencies)
    return references
def list_children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns the child trackables of ``obj`` as ``TrackableReference``s.

    ``obj`` is lazily initialized first. Each child reported by
    ``obj._trackable_children`` is then passed through
    ``converter.convert_to_trackable`` (with ``obj`` as parent) before being
    wrapped.

    Args:
      obj: A `Trackable` object.
      save_type: A string, can be 'savedmodel' or 'checkpoint'.
      **kwargs: Forwarded unchanged to ``obj._trackable_children``.

    Returns:
      A list of ``base.TrackableReference`` in the order the children were
      reported. When ``obj`` is this view's root and
      ``self._attached_dependencies`` is non-empty, those entries are appended.
    """
    # pylint: disable=protected-access
    obj._maybe_initialize_trackable()
    raw_children = obj._trackable_children(save_type, **kwargs)
    # pylint: enable=protected-access
    references = [
        base.TrackableReference(
            child_name, converter.convert_to_trackable(child, parent=obj))
        for child_name, child in raw_children.items()
    ]
    # The graph view may define extra children of the root that are not
    # actually attached to it, e.g. a Checkpoint object's save_counter.
    if obj is self.root and self._attached_dependencies:
        references.extend(self._attached_dependencies)
    return references