def list_dependencies(self, obj):
    """Overrides a parent method to include `add_object` objects.

    Yields the parent's dependencies of `obj`, followed by any extra
    dependencies whose names the parent did not report. The extra
    dependencies are the result of `self.list_extra_dependencies(obj)`
    updated with `self._extra_dependencies.get(obj, {})`.

    Args:
        obj: The object whose dependencies are listed.

    Yields:
        `base.TrackableReference` objects, one per dependency name.

    Raises:
        ValueError: If a dependency reported by the parent shares its name
            with an extra dependency other than
            `signature_serialization.SIGNATURE_ATTRIBUTE_NAME`. Because this
            is a generator, the error surfaces during iteration.
    """
    extra_dependencies = self.list_extra_dependencies(obj)
    # NOTE(review): this updates the mapping returned by
    # `list_extra_dependencies` in place -- confirm that mapping is not
    # shared with other callers.
    extra_dependencies.update(self._extra_dependencies.get(obj, {}))

    used_names = set()
    for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj):
        used_names.add(name)
        if name not in extra_dependencies:
            yield base.TrackableReference(name, dep)
            continue
        # Extra dependencies (except for `.signatures`, which is always added
        # when saving) should not have naming conflicts with dependencies
        # defined by the user.
        if name != signature_serialization.SIGNATURE_ATTRIBUTE_NAME:
            raise ValueError(
                "Error when exporting object {} with identifier={}. The "
                "object has an attribute named {}, which is reserved. List "
                "of all reserved attributes: {}".format(
                    obj,
                    obj._object_identifier,  # pylint: disable=protected-access
                    name,
                    # A list prints as ['a', 'b'] instead of dict_keys([...]).
                    list(extra_dependencies)))
        # The extra dependency takes precedence over the parent's value.
        yield base.TrackableReference(name, extra_dependencies[name])

    # Emit the extra dependencies that the parent did not already name.
    for name, dep in extra_dependencies.items():
        if name in used_names:
            continue
        yield base.TrackableReference(name, dep)
def list_dependencies(self, obj):
    """Lists dependencies of `obj`, including objects added via `add_object`.

    Dependencies reported by the parent class come first; when an entry in
    `self._extra_dependencies[obj]` has the same name, its value replaces the
    parent's. Remaining extra entries are yielded afterwards.

    Args:
        obj: The object whose dependencies are listed.

    Yields:
        `base.TrackableReference` objects, one per dependency name.
    """
    overrides = self._extra_dependencies.get(obj, {})
    seen = set()

    parent_dependencies = super(
        _AugmentedGraphView, self).list_dependencies(obj)
    for dep_name, parent_dep in parent_dependencies:
        seen.add(dep_name)
        chosen = overrides[dep_name] if dep_name in overrides else parent_dep
        yield base.TrackableReference(dep_name, chosen)

    # Whatever the parent did not name is appended as-is.
    for dep_name, extra_dep in overrides.items():
        if dep_name not in seen:
            yield base.TrackableReference(dep_name, extra_dep)
def _checkpoint_dependencies(self):
    """From Trackable. Collects the weights belonging to the current graph.

    `self._weights` is keyed by `(name, graph_key)`. Only entries whose graph
    key equals the current one are kept; under eager execution the current
    key is `None`.

    Returns:
        The parent's `_checkpoint_dependencies` followed by one
        `trackable.TrackableReference` per matching weight, ordered by name.
    """
    if context.executing_eagerly():
        current_key = None
    else:
        current_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access

    # Sort on the name component only, so graph keys are never compared.
    by_name = sorted(self._weights.items(), key=lambda entry: entry[0][0])
    graph_weights = [
        trackable.TrackableReference(name=weight_name, ref=weight)
        for (weight_name, weight_key), weight in by_name
        if weight_key == current_key
    ]
    return super(LossScale, self)._checkpoint_dependencies + graph_weights
def _checkpoint_dependencies(self):
    """From Trackable. Collects non-slot variables of the default graph.

    `self._non_slot_dict` is keyed by tuples whose first element is the
    variable name. A variable is kept when its `_graph_key` equals the
    default graph's `_graph_key`.

    Returns:
        The parent's `_checkpoint_dependencies` followed by one
        `trackable.TrackableReference` per matching variable, ordered by
        name.
    """
    # pylint: disable=protected-access
    default_graph_key = ops.get_default_graph()._graph_key
    # Sort on the name alone to avoid comparing graphs.
    ordered_items = sorted(
        self._non_slot_dict.items(), key=lambda item: item[0][0])
    non_slot_references = [
        trackable.TrackableReference(name=var_name, ref=var)
        for (var_name, _), var in ordered_items
        if var._graph_key == default_graph_key
    ]
    # pylint: enable=protected-access
    return super(Optimizer, self)._checkpoint_dependencies + non_slot_references
def __init__(self, mesh: layout.Mesh, root=None, **kwargs):
    """Initializes the checkpoint and installs a `DTrackableSaver`.

    Args:
        mesh: Stored as `self._mesh` and handed to `DTrackableSaver`.
        root: Optional root object. When truthy it is validated with
            `util._assert_trackable`, becomes the object the saver is
            rooted at, and is itself added to `kwargs` under `"root"`.
        **kwargs: Each entry is set as an attribute on this object and the
            resulting attribute value is validated with
            `util._assert_trackable`. With a root, entries that root does
            not already track are attached as extra dependencies.

    Raises:
        ValueError: If `root` already has a dependency named like a keyword
            argument and it compares unequal to the supplied value.
    """
    # pylint: disable=protected-access
    super(DTensorCheckpoint, self).__init__(root=root, **kwargs)
    self._mesh = mesh

    self._save_counter = None  # Created lazily for restore-on-create.
    self._save_assign_op = None

    saver_root = self
    attached_dependencies = None

    if root:
        util._assert_trackable(root, "root")
        saver_root = root
        attached_dependencies = []

        # All keyword arguments (including root itself) are set as children
        # of root.
        kwargs["root"] = root
        root._maybe_initialize_trackable()

        root_save_counter = root._lookup_dependency("save_counter")
        self._save_counter = data_structures.NoDependency(root_save_counter)
        self._root = data_structures.NoDependency(root)

    for attr_name, attr_value in sorted(
            kwargs.items(), key=lambda item: item[0]):
        setattr(self, attr_name, attr_value)

        # Read the attribute back rather than using attr_value directly:
        # setattr converts a list/dict/tuple into a Trackable data structure.
        tracked_value = getattr(self, attr_name)
        util._assert_trackable(tracked_value, attr_name)

        if not root:
            continue

        # Make sure that root doesn't already have dependencies with these
        # names.
        attached_dependencies = attached_dependencies or []
        existing_child = root._lookup_dependency(attr_name)
        if existing_child is None:
            attached_dependencies.append(
                base.TrackableReference(attr_name, tracked_value))
        elif existing_child != tracked_value:
            raise ValueError(
                "Cannot create a Checkpoint with keyword argument {name} if "
                "root.{name} already exists.".format(name=attr_name))

    # DTensor Change:
    # Override the parents saver with DTrackableSaver with _SingleDeviceSaver.
    graph_view = graph_view_lib.ObjectGraphView(
        weakref.ref(saver_root),
        attached_dependencies=attached_dependencies)
    self._saver = DTrackableSaver(mesh, graph_view)
    # pylint: enable=protected-access
def _breadth_first_traversal(self):
    """Find shortest paths to all dependencies of self.root.

    Returns:
        A `(visit_order, shortest_paths)` pair. `visit_order` lists every
        reachable object in breadth-first order starting with the root.
        `shortest_paths` maps each of those objects to a tuple of
        `base.TrackableReference` edges leading to it from the root (the
        root maps to the empty tuple).
    """
    visit_order = []
    shortest_paths = object_identity.ObjectIdentityDictionary()
    shortest_paths[self.root] = ()
    pending = collections.deque([self.root])

    while pending:
        parent = pending.popleft()
        visit_order.append(parent)
        for child_name, child in self.list_children(parent):
            if child in shortest_paths:
                # Already reached through a path that is no longer.
                continue
            edge = base.TrackableReference(child_name, child)
            shortest_paths[child] = shortest_paths[parent] + (edge,)
            pending.append(child)

    return visit_order, shortest_paths
def _breadth_first_traversal(self):
    """Find shortest paths to all dependencies of self.root.

    Returns:
        A `(visit_order, paths)` pair. `visit_order` lists every reachable
        object in breadth-first order starting with the root. `paths` maps
        each of those objects to a tuple of `base.TrackableReference` edges
        leading to it from the root (the root maps to the empty tuple).

    Raises:
        NotImplementedError: If a visited object is an instance of
            `tracking.NotTrackable`.
    """
    visit_order = []
    paths = object_identity.ObjectIdentityDictionary()
    paths[self.root] = ()
    pending = collections.deque([self.root])

    while pending:
        node = pending.popleft()
        if isinstance(node, tracking.NotTrackable):
            raise NotImplementedError((
                "The object %s does not support object-based saving. File a "
                "feature request if this limitation bothers you. In the meantime, "
                "you can remove the dependency on this object and save everything "
                "else.") % (node, ))
        visit_order.append(node)
        for dep_name, dep in self.list_dependencies(node):
            if dep in paths:
                # Already reached through a path that is no longer.
                continue
            edge = base.TrackableReference(dep_name, dep)
            paths[dep] = paths[node] + (edge,)
            pending.append(dep)

    return visit_order, paths
def _lookup_dependency(self, name):
    """Return the dependency called `name`, creating a placeholder if absent.

    `_lookup_dependency` is normally a pure existence check by name. Here a
    missing dependency is filled in with an empty-array `_NumpyWrapper`,
    which gives attribute re-creation behavior when loading a checkpoint.

    Args:
        name: The name of the dependency being checked.

    Returns:
        An existing dependency if one exists, or a new `_NumpyWrapper`
        placeholder dependency (which will generally be restored
        immediately).
    """
    existing = super(NumpyState, self)._lookup_dependency(name)
    if existing is not None:
        return existing

    placeholder = _NumpyWrapper(numpy.array([]))
    # Register the placeholder in both dependency records, then expose it as
    # an attribute through the parent's __setattr__.
    self._unconditional_checkpoint_dependencies.append(
        base.TrackableReference(name=name, ref=placeholder))
    self._unconditional_dependency_names[name] = placeholder
    super(NumpyState, self).__setattr__(name, placeholder)
    return placeholder
def list_children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all child trackables attached to obj.

    Args:
        obj: A `Trackable` object.
        save_type: Forwarded to `obj._trackable_children`; defaults to
            `base.SaveType.CHECKPOINT`.
        **kwargs: kwargs to use when retrieving the object's children.

    Returns:
        A list of `base.TrackableReference`, one per child of `obj`. For the
        root object, any `self._attached_dependencies` are appended.
    """
    # pylint: disable=protected-access
    obj._maybe_initialize_trackable()
    named_children = obj._trackable_children(save_type, **kwargs)
    # pylint: enable=protected-access
    children = [
        base.TrackableReference(child_name, child)
        for child_name, child in named_children.items()
    ]

    if obj is not self.root or not self._attached_dependencies:
        return children

    # GraphView objects may define children of the root object that are not
    # actually attached, e.g. a Checkpoint object's save_counter.
    children.extend(self._attached_dependencies)
    return children