def _restore_checkpoint(self):
  """Load state from checkpoint into the deserialized objects.

  Restores variable values from the SavedModel's variables file into the
  objects reconstructed by this loader. In graph mode the restore ops are
  wired into each variable's initializer so standard initialization flows
  (e.g. ManagedSession) pick them up automatically.

  Raises:
    NotImplementedError: if a checkpointed object that still has pending
      restore ops is not a resource variable.
  """
  variables_path = saved_model_utils.get_variables_path(self._export_dir)
  # TODO(andresp): Clean use of private methods of TrackableSaver.
  # pylint: disable=protected-access
  saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
  with ops.device("CPU"):
    saver._file_prefix_placeholder = constant_op.constant(variables_path)
  load_status = saver.restore(variables_path)
  load_status.assert_existing_objects_matched()
  checkpoint = load_status._checkpoint

  # When running in eager mode, the `restore` call above has already run and
  # restored the state of trackables, call `position.restore_ops()` will
  # return an empty list as there is nothing left to do. In graph mode, that
  # will return the list of ops that must run to restore the object on that
  # position. We have to wire them in the initializers of the objects so that
  # they get initialized properly when using common practices (e.g. the ones
  # used by ManagedSession) without further user action.
  for object_id, obj in dict(checkpoint.object_by_proto_id).items():
    position = base.CheckpointPosition(checkpoint=checkpoint,
                                       proto_id=object_id)
    restore_ops = position.restore_ops()
    if restore_ops:
      if resource_variable_ops.is_resource_variable(obj):
        # Fix: `_initializer_op` must be a single op, not a list. Collapse
        # multiple restore ops into one with `group`, matching the handling
        # in the sibling implementations of this method.
        if len(restore_ops) == 1:
          obj._initializer_op = restore_ops[0]
        else:
          obj._initializer_op = control_flow_ops.group(*restore_ops)
      else:
        raise NotImplementedError(
            ("Missing functionality to restore state of object "
             "%r from the checkpoint." % obj))
# Beispiel #2 (score: 0)
    def _restore_checkpoint(self):
        """Restore checkpointed values into the freshly deserialized objects."""
        variables_path = saved_model_utils.get_variables_path(self._export_dir)
        # TODO(b/205010730): Clean use of private methods of TrackableSaver.
        # pylint: disable=protected-access
        object_view = graph_view.ObjectGraphView(self.get(0))
        saver = util.TrackableSaver(object_view)
        with ops.device("CPU"):
            saver._file_prefix_placeholder = constant_op.constant(
                variables_path)
        if self._save_options.allow_partial_checkpoint:
            load_status = saver.restore(
                variables_path, self._checkpoint_options).expect_partial()
            load_status.assert_nontrivial_match()
        else:
            load_status = saver.restore(variables_path,
                                        self._checkpoint_options)
            load_status.assert_existing_objects_matched()
        ckpt = load_status._checkpoint

        if context.executing_eagerly():
            # The `restore` call above already applied all state eagerly;
            # there is nothing to wire into initializers.
            return
        # Graph mode: `position.restore_ops()` returns the cached ops that
        # must run to restore each object. Wire them into the objects'
        # initializers so that common initialization flows (e.g. the ones
        # used by ManagedSession) run them without further user action.
        for proto_id, trackable in dict(ckpt.object_by_proto_id).items():
            position = base.CheckpointPosition(checkpoint=ckpt,
                                               proto_id=proto_id)
            registered_saver = position.get_registered_saver_name()
            if registered_saver:
                raise NotImplementedError(
                    "Loading a SavedModel that uses registered checkpoint saver is "
                    f"not supported in graph mode. The loaded object {trackable} uses the "
                    f"saver registered with the name {registered_saver}.")

            restore_ops = position.restore_ops()
            if not restore_ops:
                continue
            if resource_variable_ops.is_resource_variable(trackable):
                trackable._initializer_op = (
                    restore_ops[0] if len(restore_ops) == 1
                    else control_flow_ops.group(*restore_ops))
            elif isinstance(trackable, lookup_ops.LookupInterface):
                # We don't need to check for eager execution here, since this
                # code path is only taken when restoring in graph mode.
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      restore_ops)
            else:
                raise NotImplementedError(
                    f"Unable to restore state of object {trackable} from the checkpoint."
                )
# Beispiel #3 (score: 0)
    def _restore_checkpoint(self):
        """Load checkpoint state into the deserialized object graph."""
        variables_path = saved_model_utils.get_variables_path(self._export_dir)
        # TODO(andresp): Clean use of private methods of TrackableSaver.
        # pylint: disable=protected-access
        root_view = graph_view.ObjectGraphView(self.get(0))
        saver = util.TrackableSaver(root_view)
        with ops.device("CPU"):
            saver._file_prefix_placeholder = constant_op.constant(
                variables_path)
        restore_result = saver.restore(variables_path,
                                       self._checkpoint_options)
        if self._expect_partial_checkpoint:
            load_status = restore_result.expect_partial()
        else:
            load_status = restore_result
        load_status.assert_existing_objects_matched()
        ckpt = load_status._checkpoint

        # In eager mode the `restore` call above already applied all state, so
        # `position.restore_ops()` yields nothing. In graph mode it yields the
        # ops that must run to restore each object; wire them into the
        # objects' initializers so common initialization flows (e.g. the ones
        # used by ManagedSession) pick them up without further user action.
        for node_id, trackable in dict(ckpt.object_by_proto_id).items():
            position = base.CheckpointPosition(checkpoint=ckpt,
                                               proto_id=node_id)
            pending_ops = position.restore_ops()
            if not pending_ops:
                continue
            if resource_variable_ops.is_resource_variable(trackable):
                if len(pending_ops) > 1:
                    trackable._initializer_op = control_flow_ops.group(
                        *pending_ops)
                else:
                    trackable._initializer_op = pending_ops[0]
            elif isinstance(trackable, lookup_ops.LookupInterface):
                # No eager-execution check needed: this path is only taken
                # when restoring in graph mode.
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      pending_ops)
            else:
                raise NotImplementedError(
                    ("Missing functionality to restore state of object "
                     "%r from the checkpoint." % trackable))
    def restore(self, save_path, options=None):
        """Restore a training checkpoint with host mesh placement.

        DTensor-aware variant of `TrackableSaver.restore`: the checkpoint
        file prefix is packed onto the host mesh so the string input is
        broadcast correctly, and a `_DCheckpointRestoreCoordinator` is used
        so that DTensor-specific savers run the restore.

        Args:
          save_path: Path prefix of the checkpoint to restore, or `None` to
            skip restoring and return an initialization-only status.
          options: Optional `CheckpointOptions`; a default instance is
            created when omitted.

        Returns:
          A load-status object: `InitializationOnlyStatus` when `save_path`
          is `None`, `NameBasedSaverStatus` for checkpoints without an
          object graph proto, or `CheckpointLoadStatus` otherwise.
        """
        options = options or checkpoint_options.CheckpointOptions()
        if save_path is None:
            # Nothing to restore; callers get a status that only tracks
            # initialization.
            return util.InitializationOnlyStatus(self._graph_view, ops.uid())
        reader = py_checkpoint_reader.NewCheckpointReader(save_path)
        graph_building = not context.executing_eagerly()
        if graph_building:
            dtype_map = None
        else:
            dtype_map = reader.get_variable_to_dtype_map()
        try:
            object_graph_string = reader.get_tensor(
                base.OBJECT_GRAPH_PROTO_KEY)
        except errors_impl.NotFoundError:
            # The object graph proto does not exist in this checkpoint. Try the
            # name-based compatibility mode.
            restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
                save_path=save_path,
                dtype_map=dtype_map)
            if not graph_building:
                # Eager mode: immediately push name-based restores onto every
                # known trackable.
                for existing_trackable in self._graph_view.list_objects():
                    # pylint: disable=protected-access
                    existing_trackable._maybe_initialize_trackable()
                    existing_trackable._name_based_restores.add(
                        restore_coordinator)
                    existing_trackable._name_based_attribute_restore(
                        restore_coordinator)
                    # pylint: enable=protected-access
            return util.NameBasedSaverStatus(restore_coordinator,
                                             graph_view=self._graph_view)

        if graph_building:
            if self._file_prefix_placeholder is None:
                # DTensor change: provide a hint for mesh broadcasting to put the input
                # onto the host mesh. The placeholder value ("model") is fed
                # with the real path via `file_prefix_feed_dict` below.
                self._file_prefix_placeholder = api.pack(
                    [constant_op.constant("model")] *
                    self._mesh.num_local_devices(),
                    layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
            file_prefix_tensor = self._file_prefix_placeholder
            file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
        else:
            # DTensor change: provide a hint for mesh broadcasting to put the input
            # onto the host mesh.
            file_prefix_tensor = api.pack([constant_op.constant(save_path)] *
                                          self._mesh.num_local_devices(),
                                          layout.Layout.replicated(
                                              self._mesh.host_mesh(), rank=0))
            file_prefix_feed_dict = None
        object_graph_proto = (
            trackable_object_graph_pb2.TrackableObjectGraph())
        object_graph_proto.ParseFromString(object_graph_string)
        # DTensor Change: Hook the proper DSaver in restore.
        checkpoint = _DCheckpointRestoreCoordinator(
            mesh=self._mesh,
            object_graph_proto=object_graph_proto,
            save_path=save_path,
            save_path_tensor=file_prefix_tensor,
            reader=reader,
            restore_op_cache=self._restore_op_cache,
            graph_view=self._graph_view,
            options=options,
            saveables_cache=self._saveables_cache)
        # Proto node 0 is the root object; restoring it kicks off the
        # recursive restore of everything reachable from the root.
        base.CheckpointPosition(checkpoint=checkpoint,
                                proto_id=0).restore(self._graph_view.root)

        # Attached dependencies are not attached to the root, so should be restored
        # separately.
        if self._graph_view.attached_dependencies:
            for ref in self._graph_view.attached_dependencies:
                if ref.name == "root":
                    # Root dependency is automatically added to attached dependencies --
                    # this can be ignored since it maps back to the root object.
                    continue
                proto_id = None
                # Find proto ID of attached dependency (if it is in the proto).
                for proto_ref in object_graph_proto.nodes[0].children:
                    if proto_ref.local_name == ref.name:
                        proto_id = proto_ref.node_id
                        break

                if proto_id in checkpoint.object_by_proto_id:
                    # Object has already been restored. This can happen when there's an
                    # indirect connection from the attached object to the root.
                    continue

                # NOTE(review): if the name was not found above, `proto_id` is
                # still None here and is passed through — presumably handled
                # downstream; confirm against CheckpointPosition.
                base.CheckpointPosition(checkpoint=checkpoint,
                                        proto_id=proto_id).restore(ref.ref)

        load_status = util.CheckpointLoadStatus(
            checkpoint,
            graph_view=self._graph_view,
            feed_dict=file_prefix_feed_dict)
        return load_status