def list_dependencies(self, obj):
  """Overrides a parent method to include `add_object` objects.

  The parent view's dependencies for `obj` are yielded first, in the parent's
  order. If an extra dependency registered for `obj` shares a name with one of
  them, the extra object is yielded in its place. Any extra dependencies whose
  names the parent did not report are yielded afterwards.

  Args:
    obj: The object whose dependencies are being listed.

  Yields:
    `base.CheckpointableReference` objects built from (name, dependency).
  """
  overrides = self._extra_dependencies.get(obj, {})
  reported = set()
  parent_view = super(_AugmentedGraphView, self)
  for dep_name, parent_dep in parent_view.list_dependencies(obj):
    reported.add(dep_name)
    # An extra dependency with the same name shadows the parent's one.
    yield base.CheckpointableReference(
        dep_name, overrides.get(dep_name, parent_dep))
  # Emit the extra dependencies that did not replace anything above.
  for dep_name, extra_dep in overrides.items():
    if dep_name not in reported:
      yield base.CheckpointableReference(dep_name, extra_dep)
def _checkpoint_dependencies(self):
  """From Checkpointable. Gather graph-specific non-slot variables to save.

  Non-slot variables are keyed by (name, graph). Only the variables whose graph
  key matches the current default graph are included. They are ordered by name
  and appended after the parent class's dependencies.

  Returns:
    The parent's checkpoint dependencies followed by one
    `checkpointable.CheckpointableReference` per matching non-slot variable.
  """
  active_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  # Sort on the name component only; the other key component is a graph, and
  # graphs should not be compared.
  ordered_entries = sorted(self._non_slot_dict.items(),
                           key=lambda entry: entry[0][0])
  graph_local_refs = [
      checkpointable.CheckpointableReference(name=var_name, ref=var)
      for (var_name, _), var in ordered_entries
      if var._graph_key == active_key  # pylint: disable=protected-access
  ]
  inherited = super(Optimizer, self)._checkpoint_dependencies
  return inherited + graph_local_refs
def _lookup_dependency(self, name):
  """Create placeholder NumPy arrays for to-be-restored attributes.

  Typically `_lookup_dependency` is used to check by name whether a dependency
  exists. We cheat slightly by creating a checkpointable object for `name` if
  we don't already have one, giving us attribute re-creation behavior when
  loading a checkpoint.

  Args:
    name: The name of the dependency being checked.
  Returns:
    An existing dependency if one exists, or a new `_NumpyWrapper` placeholder
    dependency (which will generally be restored immediately).
  """
  existing = super(NumpyState, self)._lookup_dependency(name)
  if existing is not None:
    return existing
  # Nothing tracked under `name` yet: register an empty-array placeholder.
  placeholder = _NumpyWrapper(numpy.array([]))
  self._unconditional_checkpoint_dependencies.append(
      base.CheckpointableReference(name=name, ref=placeholder))
  self._unconditional_dependency_names[name] = placeholder
  # Assign via the parent's __setattr__, not this class's own attribute
  # hook (presumably so the dependency is not tracked twice; confirm).
  super(NumpyState, self).__setattr__(name, placeholder)
  return placeholder
def _breadth_first_traversal(self):
  """Find shortest paths to all dependencies of self.root.

  Walks the dependency graph breadth-first from `self.root`, following
  `self.list_dependencies`. The first time an object is reached, the path used
  to reach it is recorded, which is therefore a shortest path in edges.

  Returns:
    A tuple `(bfs_sorted, path_to_root)`:
      bfs_sorted: The reachable objects, in breadth-first visit order.
      path_to_root: An identity-keyed mapping from each reachable object to a
        tuple of `base.CheckpointableReference` leading from the root to it.
        The root maps to `()`.

  Raises:
    NotImplementedError: If a reachable object is a
      `tracking.NotCheckpointable`.
  """
  root = self.root
  visit_order = []
  paths = object_identity.ObjectIdentityDictionary()
  paths[root] = ()
  frontier = collections.deque((root,))
  while frontier:
    node = frontier.popleft()
    if isinstance(node, tracking.NotCheckpointable):
      raise NotImplementedError(
          ("The object %s does not support object-based saving. File a "
           "feature request if this limitation bothers you. In the meantime, "
           "you can remove the dependency on this object and save everything "
           "else.") % (node,))
    visit_order.append(node)
    node_path = paths[node]
    for edge_name, child in self.list_dependencies(node):
      # Already-reached objects already hold a path at least as short.
      if child in paths:
        continue
      paths[child] = node_path + (base.CheckpointableReference(edge_name, child),)
      frontier.append(child)
  return visit_order, paths