Esempio n. 1
0
 def list_dependencies(self, obj):
     """Overrides a parent method to include `add_object` objects.

     First yields every dependency the parent class reports for `obj`, in the
     parent's order. When a name also appears in `self._extra_dependencies`
     for `obj`, the registered extra object replaces the parent's. Afterwards,
     yields the remaining extra dependencies whose names the parent did not
     report.

     Args:
       obj: The object whose dependencies are being listed.

     Yields:
       `base.CheckpointableReference(name, dependency)` values.
     """
     overrides = self._extra_dependencies.get(obj, {})
     reported = set()
     parent_view = super(_AugmentedGraphView, self)
     for dep_name, original in parent_view.list_dependencies(obj):
         reported.add(dep_name)
         # A registration under the same name takes precedence over the
         # parent's dependency.
         chosen = overrides[dep_name] if dep_name in overrides else original
         yield base.CheckpointableReference(dep_name, chosen)
     # Emit the extra dependencies the parent knew nothing about.
     for dep_name, extra in overrides.items():
         if dep_name not in reported:
             yield base.CheckpointableReference(dep_name, extra)
Esempio n. 2
0
 def _checkpoint_dependencies(self):
   """From Checkpointable. Gather graph-specific non-slot variables to save.

   Returns the parent class's `_checkpoint_dependencies`, followed by one
   `checkpointable.CheckpointableReference` per entry of
   `self._non_slot_dict` whose variable's `_graph_key` equals the default
   graph's `_graph_key`. Those entries are ordered by name.

   NOTE(review): the parent's `_checkpoint_dependencies` is read as an
   attribute (not called), so this is presumably exposed as a property by a
   decorator outside this view; confirm.
   """
   default_graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
   # Sort on the name component only, so that the second key component
   # (a graph) is never compared.
   by_name = sorted(self._non_slot_dict.items(), key=lambda entry: entry[0][0])
   graph_local_refs = [
       checkpointable.CheckpointableReference(name=var_name, ref=var)
       for (var_name, _), var in by_name
       if var._graph_key == default_graph_key  # pylint: disable=protected-access
   ]
   return super(Optimizer, self)._checkpoint_dependencies + graph_local_refs
Esempio n. 3
0
    def _lookup_dependency(self, name):
        """Create placeholder NumPy arrays for to-be-restored attributes.

        Typically `_lookup_dependency` is used to check by name whether a
        dependency exists. We cheat slightly by creating a checkpointable
        object for `name` if we don't already have one, giving us attribute
        re-creation behavior when loading a checkpoint.

        Args:
          name: The name of the dependency being checked.
        Returns:
          An existing dependency if one exists, or a new `_NumpyWrapper`
          placeholder dependency (which will generally be restored
          immediately).
        """
        existing = super(NumpyState, self)._lookup_dependency(name)
        if existing is not None:
            return existing
        # No dependency yet: register an empty-array placeholder under `name`.
        placeholder = _NumpyWrapper(numpy.array([]))
        self._unconditional_checkpoint_dependencies.append(
            base.CheckpointableReference(name=name, ref=placeholder))
        self._unconditional_dependency_names[name] = placeholder
        # Set the attribute through the parent's __setattr__, skipping
        # NumpyState's own override.
        super(NumpyState, self).__setattr__(name, placeholder)
        return placeholder
Esempio n. 4
0
 def _breadth_first_traversal(self):
   """Find shortest paths to all dependencies of self.root.

   Walks the dependency graph breadth-first, starting at `self.root` and
   following `self.list_dependencies`. Each object is recorded the first time
   it is reached, so its recorded path has the fewest edges.

   Returns:
     A tuple `(bfs_sorted, path_to_root)`:
       - a list of the reachable objects in visit order, root first;
       - an `object_identity.ObjectIdentityDictionary` that maps each object
         to a tuple of `base.CheckpointableReference`s leading from the root
         to it. The root maps to `()`.

   Raises:
     NotImplementedError: if a reached object is a
       `tracking.NotCheckpointable`.
   """
   visit_order = []
   paths = object_identity.ObjectIdentityDictionary()
   frontier = collections.deque([self.root])
   paths[self.root] = ()
   while frontier:
     node = frontier.popleft()
     if isinstance(node, tracking.NotCheckpointable):
       raise NotImplementedError(
           ("The object %s does not support object-based saving. File a "
            "feature request if this limitation bothers you. In the meantime, "
            "you can remove the dependency on this object and save everything "
            "else.")
           % (node,))
     visit_order.append(node)
     for name, child in self.list_dependencies(node):
       if child in paths:
         # Already reached via an equally short or shorter path.
         continue
       paths[child] = paths[node] + (
           base.CheckpointableReference(name, child),)
       frontier.append(child)
   return visit_order, paths