Ejemplo n.º 1
0
    def __init__(self, mesh: layout.Mesh, root=None, **kwargs):
        """Creates a checkpoint whose saver is a mesh-aware `DTrackableSaver`.

        Args:
          mesh: The `layout.Mesh` handed to the `DTrackableSaver` built below.
          root: Optional trackable. When truthy, it becomes the root of the
            saved object graph. Keyword arguments that `root` does not already
            depend on are added to that graph as attached dependencies.
          **kwargs: Objects to track, keyed by dependency name. Each one is
            set as an attribute of this checkpoint. The value read back (which
            `setattr` may have wrapped if it was a list/dict/tuple) is checked
            with `util._assert_trackable`.

        Raises:
          ValueError: If `root` already has a dependency under one of the
            keyword names and that dependency compares unequal to the value
            passed in.
        """
        super(DTensorCheckpoint, self).__init__(root=root, **kwargs)
        self._mesh = mesh

        # The save counter is created lazily for restore-on-create.
        self._save_counter = None
        self._save_assign_op = None

        # Without a root, this checkpoint is the graph root and there are no
        # extra attached dependencies.
        graph_root = self
        extra_refs = None
        if root:
            util._assert_trackable(root, "root")
            graph_root = root
            extra_refs = []
            # Every keyword argument, root itself included, is attached as a
            # child of root.
            kwargs["root"] = root
            root._maybe_initialize_trackable()
            self._save_counter = data_structures.NoDependency(
                root._lookup_dependency("save_counter"))
            self._root = data_structures.NoDependency(root)

        for dep_name in sorted(kwargs):
            setattr(self, dep_name, kwargs[dep_name])
            # Read the value back: setattr may wrap a list/dict/tuple in a
            # Trackable data structure, and the wrapped object is what counts.
            tracked = getattr(self, dep_name)
            util._assert_trackable(tracked, dep_name)
            if not root:
                continue
            # Defensive: only None here if root's truthiness changed since the
            # check above.
            if extra_refs is None:
                extra_refs = []
            # Refuse to shadow a different dependency root already holds under
            # this name. Unseen names become attached dependencies of root.
            existing = root._lookup_dependency(dep_name)
            if existing is None:
                extra_refs.append(base.TrackableReference(dep_name, tracked))
            elif existing != tracked:
                raise ValueError(
                    "Cannot create a Checkpoint with keyword argument {name} if "
                    "root.{name} already exists.".format(name=dep_name))

        # DTensor change: override the parent's saver with a DTrackableSaver
        # bound to `mesh`.
        self._saver = DTrackableSaver(
            mesh,
            graph_view_lib.ObjectGraphView(
                weakref.ref(graph_root),
                attached_dependencies=extra_refs))
Ejemplo n.º 2
0
 def _descendants_with_paths(self):
     """Walks the object graph breadth-first starting at `self.root`.

     Returns:
       A `(nodes, paths)` tuple.
       - `nodes` lists every node reachable from `self.root`, root included,
         in breadth-first order.
       - `paths` is an `object_identity.ObjectIdentityDictionary` mapping each
         node to a tuple of `base.TrackableReference`s leading from the root
         to that node. The root maps to `()`.
       A node reachable along several routes keeps the first path discovered,
       which is a shortest one.
     """
     paths = object_identity.ObjectIdentityDictionary()
     paths[self.root] = ()
     ordered = [self.root]
     # `ordered` doubles as the FIFO queue. `head` indexes the next node to
     # expand, and newly discovered nodes are appended at the tail.
     head = 0
     while head < len(ordered):
         node = ordered[head]
         head += 1
         prefix = paths[node]
         for edge_name, child in self.children(node).items():
             if child in paths:
                 continue
             paths[child] = prefix + (base.TrackableReference(edge_name, child),)
             ordered.append(child)
     return ordered, paths
Ejemplo n.º 3
0
 def _breadth_first_traversal(self):
     """Finds the shortest path from `self.root` to each of its dependencies.

     Returns:
       A `(order, paths)` tuple.
       - `order` holds every trackable reachable via `self.list_children`,
         starting with `self.root`, in breadth-first order.
       - `paths` is an `object_identity.ObjectIdentityDictionary` mapping each
         trackable to the tuple of `base.TrackableReference`s that leads to it
         from the root. The root maps to `()`.
     """
     visited_paths = object_identity.ObjectIdentityDictionary()
     visited_paths[self.root] = ()
     pending = collections.deque()
     pending.append(self.root)
     order = []
     while pending:
         trackable = pending.popleft()
         order.append(trackable)
         for child_name, child in self.list_children(trackable):
             # Identity-based membership: the first visit is the shortest path.
             if child in visited_paths:
                 continue
             edge = base.TrackableReference(child_name, child)
             visited_paths[child] = visited_paths[trackable] + (edge,)
             pending.append(child)
     return order, visited_paths
Ejemplo n.º 4
0
  def list_children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Lists the child trackables of `obj` as `base.TrackableReference`s.

    Args:
      obj: A `Trackable` object.
      save_type: Which kind of save the children are gathered for (e.g.
        'savedmodel' or 'checkpoint'). Defaults to `base.SaveType.CHECKPOINT`
        and is forwarded to the parent class's `children`.
      **kwargs: Extra keyword arguments forwarded to the parent class's
        `children`.

    Returns:
      A list with one `base.TrackableReference` per entry of the parent's
      `children(...)` mapping, in that mapping's iteration order. When `obj`
      is `self.root`, it is followed by `self._attached_dependencies`.
    """
    inherited = super(ObjectGraphView, self).children(obj, save_type, **kwargs)
    references = [
        base.TrackableReference(name, ref) for name, ref in inherited.items()
    ]
    # The root may define children that are not actually attached to it (for
    # example a Checkpoint object's save_counter), so report them too.
    if obj is self.root and self._attached_dependencies:
      references.extend(self._attached_dependencies)
    return references
Ejemplo n.º 5
0
    def list_children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
        """Lists every child trackable attached to `obj`.

        Args:
          obj: A `Trackable` object.
          save_type: Which kind of save the children are gathered for (e.g.
            'savedmodel' or 'checkpoint'). Defaults to
            `base.SaveType.CHECKPOINT` and is forwarded to
            `obj._trackable_children`.
          **kwargs: Extra keyword arguments forwarded to
            `obj._trackable_children`.

        Returns:
          A list of `base.TrackableReference`s, one per child returned by
          `obj._trackable_children`. Each child is first passed through
          `converter.convert_to_trackable` with `parent=obj`. When `obj` is
          `self.root`, the list is followed by `self._attached_dependencies`.
        """
        # pylint: disable=protected-access
        obj._maybe_initialize_trackable()
        raw_children = obj._trackable_children(save_type, **kwargs)
        references = [
            base.TrackableReference(
                name, converter.convert_to_trackable(child, parent=obj))
            for name, child in raw_children.items()
        ]
        # pylint: enable=protected-access

        # The root may define children that are not actually attached to it
        # (for example a Checkpoint object's save_counter), so report them too.
        if obj is self.root and self._attached_dependencies:
            references.extend(self._attached_dependencies)
        return references