コード例 #1
0
    def _restore_checkpoint(self):
        """Restores checkpointed state into the objects this loader recreated.

        A `TrackableSaver` whose object graph is rooted at `self.get(0)` reads
        the checkpoint found at the SavedModel's variables path. When
        `self._save_options.allow_partial_checkpoint` is set, the restore is
        marked as expecting a partial match and only a non-trivial match is
        asserted; otherwise every existing object must be matched.

        In graph mode the restore ops are not run here. They are wired into
        object initialization instead: a resource variable's `_initializer_op`
        is replaced, and a lookup table's ops are added to the
        `TABLE_INITIALIZERS` collection.

        Raises:
          NotImplementedError: In graph mode, if a restored object uses a
            registered checkpoint saver, or if an object that has restore ops
            is neither a resource variable nor a `LookupInterface`.
        """
        checkpoint_prefix = saved_model_utils.get_variables_path(
            self._export_dir)
        # TODO(b/205010730): Clean use of private methods of TrackableSaver.
        # pylint: disable=protected-access
        object_graph = graph_view.ObjectGraphView(self.get(0))
        trackable_saver = checkpoint.TrackableSaver(object_graph)
        # The file-prefix constant is created under a CPU device scope.
        with ops.device("CPU"):
            trackable_saver._file_prefix_placeholder = constant_op.constant(
                checkpoint_prefix)

        allow_partial = self._save_options.allow_partial_checkpoint
        if allow_partial:
            status = trackable_saver.restore(
                checkpoint_prefix, self._checkpoint_options).expect_partial()
            status.assert_nontrivial_match()
        else:
            status = trackable_saver.restore(
                checkpoint_prefix, self._checkpoint_options)
            status.assert_existing_objects_matched()
        restored_checkpoint = status._checkpoint

        if context.executing_eagerly():
            # Eagerly, the `restore` call above has already restored the
            # trackables' state, and calling `position.restore_ops()` would
            # re-run the restore, so nothing is left to do.
            return

        # Graph mode: `position.restore_ops()` returns a cached list of the ops
        # that restore the object at that position. They are wired into the
        # objects' initializers so that common initialization practices (e.g.
        # those used by ManagedSession) restore them without user action.
        # The loop walks a snapshot copy of the id -> object mapping.
        for object_id, obj in dict(
                restored_checkpoint.object_by_proto_id).items():
            position = base.CheckpointPosition(
                checkpoint=restored_checkpoint, proto_id=object_id)
            registered_saver = position.get_registered_saver_name()
            if registered_saver:
                raise NotImplementedError(
                    "Loading a SavedModel that uses registered checkpoint saver is "
                    f"not supported in graph mode. The loaded object {obj} uses the "
                    f"saver registered with the name {registered_saver}.")

            pending_ops = position.restore_ops()
            if not pending_ops:
                continue

            if resource_variable_ops.is_resource_variable(obj):
                # A lone op is used directly; several are grouped into one op.
                obj._initializer_op = (
                    pending_ops[0] if len(pending_ops) == 1
                    else control_flow_ops.group(*pending_ops))
            elif isinstance(obj, lookup_ops.LookupInterface):
                # No eager check needed: the eager case returned above.
                # The whole list is added as a single collection entry.
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      pending_ops)
            else:
                raise NotImplementedError(
                    f"Unable to restore state of object {obj} from the checkpoint."
                )
コード例 #2
0
  def test_serialize_gathered_objects(self):
    """Plain variables become saveables; registered savers are split out."""
    root = autotrackable.AutoTrackable()
    root.v = variables.Variable(1.0)
    root.registered = TrackableWithRegisteredSaver()

    view = graph_view.ObjectGraphView(root)
    saveables, _, _, savers = save_util_v1.serialize_gathered_objects(view)

    # Only the plain variable is serialized as a saveable object...
    self.assertLen(saveables, 1)
    self.assertIs(saveables[0].op, root.v)
    # ...while the custom trackable is grouped under its registered saver.
    expected_savers = {
        "Custom.RegisteredSaver": {"registered": root.registered}}
    self.assertDictEqual(expected_savers, savers)
コード例 #3
0
    def __init__(self, mesh: layout.Mesh, root=None, **kwargs):
        """Creates a checkpoint whose saver is a DTensor `DTrackableSaver`.

        Args:
          mesh: DTensor mesh, stored on the instance and passed to the
            `DTrackableSaver`.
          root: Optional trackable root. When truthy, all keyword arguments
            (and `root` itself, under the name "root") are treated as children
            of `root`: any name `root` does not already depend on is recorded
            as an attached dependency of the saver's object graph.
          **kwargs: Trackable objects to checkpoint; each is set as an
            attribute of this object, in sorted key order.

        Raises:
          ValueError: If `root` already has a dependency under one of the
            keyword names that refers to a different object.
        """
        super(DTensorCheckpoint, self).__init__(root=root, **kwargs)
        self._mesh = mesh

        attached_dependencies = None
        self._save_counter = None  # Created lazily for restore-on-create.
        self._save_assign_op = None

        if root:
            util._assert_trackable(root, "root")
            saver_root = root
            attached_dependencies = []
            # Every keyword argument, root included, becomes a child of root.
            kwargs["root"] = root
            root._maybe_initialize_trackable()
            self._save_counter = data_structures.NoDependency(
                root._lookup_dependency("save_counter"))
            self._root = data_structures.NoDependency(root)
        else:
            saver_root = self

        # Keyword names are unique strings, so sorting the keys yields the same
        # order as sorting the (key, value) items by key.
        for name in sorted(kwargs):
            setattr(self, name, kwargs[name])

            # Read the attribute back rather than using the raw value: setattr
            # converts a list/dict/tuple into a Trackable data structure.
            tracked = getattr(self, name)
            util._assert_trackable(tracked, name)

            if not root:
                continue
            # `root` must not already depend on a different object under `name`.
            existing = root._lookup_dependency(name)
            if existing is None:
                attached_dependencies.append(
                    base.TrackableReference(name, tracked))
            elif existing != tracked:
                raise ValueError(
                    "Cannot create a Checkpoint with keyword argument {name} if "
                    "root.{name} already exists.".format(name=name))

        # DTensor change: replace the saver set up by the parent constructor
        # with a mesh-aware DTrackableSaver (per the original note, one backed
        # by _SingleDeviceSaver).
        object_graph = graph_view_lib.ObjectGraphView(
            weakref.ref(saver_root),
            attached_dependencies=attached_dependencies)
        self._saver = DTrackableSaver(mesh, object_graph)
コード例 #4
0
  def test_serialize_gathered_objects_with_map(self):
    """An object map makes serialization use the mapped copies instead."""
    root = autotrackable.AutoTrackable()
    root.v = variables.Variable(1.0)
    root.registered = TrackableWithRegisteredSaver()

    # Stand-ins that should be serialized in place of the originals.
    registered_copy = TrackableWithRegisteredSaver()
    v_copy = variables.Variable(1.0)
    replacements = object_identity.ObjectIdentityDictionary()
    replacements[root.registered] = registered_copy
    replacements[root.v] = v_copy

    saveables, _, _, savers = save_util_v1.serialize_gathered_objects(
        graph_view.ObjectGraphView(root), replacements)

    # The single saveable wraps the copy of `v`, not `v` itself.
    self.assertLen(saveables, 1)
    saved_op = saveables[0].op
    self.assertIsNot(saved_op, root.v)
    self.assertIs(saved_op, v_copy)

    # The registered saver likewise receives the copy.
    saver_target = savers["Custom.RegisteredSaver"]["registered"]
    self.assertIsNot(root.registered, saver_target)
    self.assertIs(registered_copy, saver_target)
コード例 #5
0
def _get_var_list(model):
    """Collects the saveable objects that a checkpoint of `model` would hold.

    Args:
      model: Root trackable whose object graph is serialized.

    Returns:
      The first of the three values returned by
      `ObjectGraphView.serialize_object_graph()`; the other two are discarded.
    """
    object_graph = graph_view.ObjectGraphView(model)
    named_saveables, _, _ = object_graph.serialize_object_graph()
    return named_saveables