def test_resource_variable_use_localhost(self):
    """Save/restore a resource variable with localhost IO options applied."""
    var_a = resource_variable_ops.ResourceVariable(2.)
    self.evaluate(var_a.initializer)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(var_a, "x"))
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(
        saver.save(constant_op.constant(ckpt_prefix), self.local_options))
    # Exactly two files are expected on disk for this checkpoint prefix.
    self.assertEqual(2, len(gfile.Glob(ckpt_prefix + "*")))
    self.evaluate(var_a.assign(1.))
    self.evaluate(saver.restore(ckpt_prefix, self.local_options))
    self.assertEqual(2., self.evaluate(var_a))

    # A brand-new variable restored from the same checkpoint picks up the
    # originally saved value.
    var_b = resource_variable_ops.ResourceVariable(3.)
    self.evaluate(var_b.initializer)
    saver_b = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(var_b, "x"))
    self.evaluate(saver_b.restore(ckpt_prefix, self.local_options))
    self.assertEqual(2., self.evaluate(var_b))

    # In graph mode, every SaveV2/RestoreV2 op must have been pinned to
    # localhost.
    if not context.executing_eagerly():
        io_op_types = ("SaveV2", "RestoreV2")
        for graph_op in ops.get_default_graph().get_operations():
            if graph_op.type in io_op_types:
                self.assertEqual(LOCALHOST, graph_op.device)
  def test_to_proto(self):
    """Round-trips a MultiDeviceSaver through its SaverDef proto form."""
    v1 = resource_variable_ops.ResourceVariable(2.)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")

    # Capture the generated proto as a side effect of tracing the wrapped
    # function.
    captured_protos = []
    wrapped = wrap_function.wrap_function(
        lambda: captured_protos.append(saver.to_proto()), signature=())
    self.assertEqual(1, len(captured_protos))
    saver_def = captured_protos[0]
    # Prune concrete save/restore callables out of the wrapped graph using the
    # tensor/op names recorded in the proto.
    save_fn = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(saver_def.filename_tensor_name),
        fetches=wrapped.graph.get_tensor_by_name(saver_def.save_tensor_name))
    restore_fn = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(saver_def.filename_tensor_name),
        fetches=wrapped.graph.get_operation_by_name(saver_def.restore_op_name))
    save_path = save_fn(constant_op.constant(prefix))
    v1.assign(1.)
    restore_fn(constant_op.constant(save_path))
    self.assertEqual(2., self.evaluate(v1))

    # A second saver over a fresh variable restores the saved value too.
    v2 = resource_variable_ops.ResourceVariable(3.)
    second_saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    second_saver.restore(save_path)
    self.assertEqual(2., self.evaluate(v2))
  def test_checkpoint_multi_device_using_localhost(self):
    """Localhost-option save/restore with variables spread over three CPUs."""
    created = []
    for device_index in range(3):
      with ops.device("cpu:%d" % device_index):
        created.append(
            resource_variable_ops.ResourceVariable(float(device_index)))
    v0, v1, v2 = created

    self.evaluate([v0.initializer, v1.initializer, v2.initializer])
    all_saveables = []
    for variable, ckpt_name in zip(created, ("v0", "v1", "v2")):
      all_saveables.extend(
          saveable_object_util.saveable_objects_for_op(variable, ckpt_name))
    saver = functional_saver.MultiDeviceSaver(all_saveables)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
    self.assertEqual(2, len(gfile.Glob(prefix + "*")))
    for variable in created:
      self.evaluate(variable.assign(-1.))
    self.evaluate(
        saver.restore(constant_op.constant(prefix), self.local_options))
    self.assertEqual(0., self.evaluate(v0))
    self.assertEqual(1., self.evaluate(v1))
    self.assertEqual(2., self.evaluate(v2))

    # In graph mode, verify that the save and restore ops were set to run on
    # localhost.
    if not context.executing_eagerly():
      for op in ops.get_default_graph().get_operations():
        if op.type in ("SaveV2", "RestoreV2", "MergeV2Checkpoints"):
          self.assertEqual(LOCALHOST, op.device)
  def test_to_proto(self):
    """Verifies that a MultiDeviceSaver round-trips through a SaverDef proto."""
    v1 = resource_variable_ops.ResourceVariable(2.)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")

    # Trace the saver inside a wrapped function; the lambda stashes the
    # generated SaverDef proto in proto_accumulator as a side effect.
    proto_accumulator = []
    wrapped = wrap_function.wrap_function(
        lambda: proto_accumulator.append(saver.to_proto()), signature=())
    self.assertEqual(1, len(proto_accumulator))
    proto = proto_accumulator[0]
    # Recover concrete save/restore callables from the wrapped graph using the
    # tensor/op names recorded in the proto.
    save = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
        fetches=wrapped.graph.get_tensor_by_name(proto.save_tensor_name))
    restore = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
        fetches=wrapped.graph.get_operation_by_name(proto.restore_op_name))
    save_path = save(constant_op.constant(prefix))
    v1.assign(1.)
    restore(constant_op.constant(save_path))
    self.assertEqual(2., self.evaluate(v1))

    # A second saver over a new variable restores the originally saved value.
    v2 = resource_variable_ops.ResourceVariable(3.)
    second_saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    second_saver.restore(save_path)
    self.assertEqual(2., self.evaluate(v2))
  def test_checkpoint_is_sharded_by_task(self):
    """Save/restore with variables placed on three separate worker tasks."""
    servers = [server_lib.Server.create_local_server() for _ in range(3)]
    cluster_spec = server_lib.ClusterSpec(
        {"worker": [s.target[len("grpc://"):] for s in servers]})
    remote.connect_to_cluster(cluster_spec)
    task_variables = []
    for task_index in range(3):
      with ops.device("/job:worker/task:%d/cpu:0" % task_index):
        task_variables.append(
            resource_variable_ops.ResourceVariable(float(task_index)))
    v0, v1, v2 = task_variables

    self.evaluate([v0.initializer, v1.initializer, v2.initializer])
    saveables = []
    for variable, ckpt_name in zip(task_variables, ("v0", "v1", "v2")):
      saveables.extend(
          saveable_object_util.saveable_objects_for_op(variable, ckpt_name))
    saver = functional_saver.MultiDeviceSaver(saveables)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix)))
    # The checkpoint is written as four files for this prefix.
    self.assertEqual(4, len(gfile.Glob(prefix + "*")))
    for variable in task_variables:
      self.evaluate(variable.assign(-1.))
    self.evaluate(saver.restore(constant_op.constant(prefix)))
    self.assertEqual(0., self.evaluate(v0))
    self.assertEqual(1., self.evaluate(v1))
    self.assertEqual(2., self.evaluate(v2))
    def test_resource_variable(self):
        """Basic save/restore round trip using functional_saver.Saver."""
        original = resource_variable_ops.ResourceVariable(2.)
        saver = functional_saver.Saver(
            saveable_object_util.saveable_objects_for_op(original, "x"))
        ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
        save_path = saver.save(constant_op.constant(ckpt_prefix))
        original.assign(1.)
        saver.restore(save_path)
        self.assertEqual(2., self.evaluate(original))

        # Restoring into a different variable via a fresh saver also yields
        # the value saved earlier.
        restored = resource_variable_ops.ResourceVariable(3.)
        other_saver = functional_saver.Saver(
            saveable_object_util.saveable_objects_for_op(restored, "x"))
        other_saver.restore(save_path)
        self.assertEqual(2., self.evaluate(restored))
  def test_resource_variable(self):
    """Basic MultiDeviceSaver save/restore round trip for one variable."""
    source_var = resource_variable_ops.ResourceVariable(2.)
    self.evaluate(source_var.initializer)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(source_var, "x"))
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(ckpt_prefix)))
    # Two files are expected on disk for this prefix.
    self.assertEqual(2, len(gfile.Glob(ckpt_prefix + "*")))
    self.evaluate(source_var.assign(1.))
    self.evaluate(saver.restore(ckpt_prefix))
    self.assertEqual(2., self.evaluate(source_var))

    # Restore into a different variable via a fresh saver.
    target_var = resource_variable_ops.ResourceVariable(3.)
    self.evaluate(target_var.initializer)
    other_saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(target_var, "x"))
    self.evaluate(other_saver.restore(ckpt_prefix))
    self.assertEqual(2., self.evaluate(target_var))
  def test_resource_variable(self):
    """Save/restore round trip via the single-device saver implementation."""
    first = resource_variable_ops.ResourceVariable(2.)
    self.evaluate(first.initializer)
    saver = functional_saver._SingleDeviceSaver(
        saveable_object_util.saveable_objects_for_op(first, "x"))
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(checkpoint_prefix)))
    # Two files are expected on disk for this prefix.
    self.assertEqual(2, len(gfile.Glob(checkpoint_prefix + "*")))
    self.evaluate(first.assign(1.))
    self.evaluate(saver.restore(checkpoint_prefix))
    self.assertEqual(2., self.evaluate(first))

    # A second saver over a fresh variable sees the originally saved value.
    second = resource_variable_ops.ResourceVariable(3.)
    self.evaluate(second.initializer)
    second_saver = functional_saver._SingleDeviceSaver(
        saveable_object_util.saveable_objects_for_op(second, "x"))
    self.evaluate(second_saver.restore(checkpoint_prefix))
    self.assertEqual(2., self.evaluate(second))
  def test_checkpoint_is_sharded_by_device(self):
    """Checkpoint written by MultiDeviceSaver for variables on three CPUs."""
    per_device_vars = []
    for cpu_index in range(3):
      with ops.device("cpu:%d" % cpu_index):
        per_device_vars.append(
            resource_variable_ops.ResourceVariable(float(cpu_index)))
    v0, v1, v2 = per_device_vars

    self.evaluate([v0.initializer, v1.initializer, v2.initializer])
    saveables = []
    for variable, ckpt_name in zip(per_device_vars, ("v0", "v1", "v2")):
      saveables.extend(
          saveable_object_util.saveable_objects_for_op(variable, ckpt_name))
    saver = functional_saver.MultiDeviceSaver(saveables)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix)))
    # The checkpoint is written as four files for this prefix.
    self.assertEqual(4, len(gfile.Glob(prefix + "*")))
    for variable in per_device_vars:
      self.evaluate(variable.assign(-1.))
    self.evaluate(saver.restore(constant_op.constant(prefix)))
    self.assertEqual(0., self.evaluate(v0))
    self.assertEqual(1., self.evaluate(v1))
    self.assertEqual(2., self.evaluate(v2))
Example #10
0
    def test_checkpoint_is_sharded_by_device(self):
        """Saves variables on three CPU devices and checks the shard count."""
        with ops.device("cpu:0"):
            v0 = resource_variable_ops.ResourceVariable(0.)
        with ops.device("cpu:1"):
            v1 = resource_variable_ops.ResourceVariable(1.)
        with ops.device("cpu:2"):
            v2 = resource_variable_ops.ResourceVariable(2.)

        self.evaluate([v0.initializer, v1.initializer, v2.initializer])
        saver = functional_saver.MultiDeviceSaver(
            list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
            list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
            list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
        prefix = os.path.join(self.get_temp_dir(), "ckpt")
        self.evaluate(saver.save(constant_op.constant(prefix)))
        # Four files are expected — presumably one shard per device plus the
        # checkpoint index; TODO confirm the exact sharding layout.
        self.assertEqual(4, len(gfile.Glob(prefix + "*")))
        self.evaluate(v0.assign(-1.))
        self.evaluate(v1.assign(-1.))
        self.evaluate(v2.assign(-1.))
        self.evaluate(saver.restore(constant_op.constant(prefix)))
        # Restore must bring back each variable's originally saved value.
        self.assertEqual(0., self.evaluate(v0))
        self.assertEqual(1., self.evaluate(v1))
        self.assertEqual(2., self.evaluate(v2))
Example #11
0
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
    """Overrides given variable's initialization op.

    Sets variable initializer to assign op that initializes variable from
    tensor's value in the checkpoint.

    Args:
      variable: `tf.Variable` object.
      ckpt_file: string, full path of the checkpoint.
      tensor_name: Name of the tensor to load from the checkpoint.
      slice_spec: Slice specification for loading partitioned tensors.
      name: Name of the operation.
    """
    base_type = variable.dtype.base_dtype
    # Do not colocate with variable since RestoreV2 op only runs on CPU and
    # colocation will force variable (and other ops that colocate with variable)
    # to be on CPU as well. It is okay to place the variable's initializer op on
    # CPU since it will only be run once at the start.
    with ops.device(variable.device), ops.device("/cpu:0"):
        restore_op = io_ops.restore_v2(ckpt_file, [tensor_name], [slice_spec],
                                       [base_type],
                                       name=name)[0]

        names_to_saveables = saveable_object_util.op_list_to_dict([variable])
        saveable_objects = []
        # Use a distinct loop variable; the original code shadowed the `name`
        # parameter here (harmless since restore_v2 already consumed it, but
        # confusing to readers).
        for saveable_name, op in names_to_saveables.items():
            for s in saveable_object_util.saveable_objects_for_op(
                    op, saveable_name):
                saveable_objects.append(s)

        assert len(saveable_objects) == 1  # Should be only one variable.
    init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

    # Rewire the variable so initialization assigns the checkpoint value.
    # pylint:disable=protected-access
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
Example #12
0
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    # Wrap the variable in SaveableObjects keyed by its checkpoint name.
    # NOTE(review): the loop variable `name` below shadows the `name`
    # parameter (already consumed by restore_v2 above); harmless but
    # confusing.
    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    saveable_objects = []
    for name, op in names_to_saveables.items():
      for s in saveable_object_util.saveable_objects_for_op(op, name):
        saveable_objects.append(s)

    assert len(saveable_objects) == 1  # Should be only one variable.
  init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

  # Rewire the variable so initialization assigns the checkpoint value.
  # pylint:disable=protected-access
  variable._initializer_op = init_op
  restore_op.set_shape(variable.shape)
  variable._initial_value = restore_op
Example #13
0
    def _add_attributes_to_object_graph(self, trackable_objects,
                                        object_graph_proto, node_ids,
                                        object_names, object_map,
                                        call_with_mapped_captures):
        """Create SaveableObjects and corresponding SerializedTensor protos.

        Args:
          trackable_objects: List of trackables, ordered by checkpoint id;
            must align with `object_graph_proto.nodes`.
          object_graph_proto: The object-graph proto whose node attributes
            are filled in here.
          node_ids: Dict mapping each trackable to its checkpoint id.
          object_names: Dict mapping each trackable to its checkpoint name.
          object_map: Optional dict substituting objects to save in place of
            the originals, or None to save the trackables themselves.
          call_with_mapped_captures: Forwarded to
            `saveable_object_util.create_saveable_object`.

        Returns:
          A `(named_saveable_objects, feed_additions)` tuple;
          `feed_additions` is None when SaveableObjects are not cached.
        """
        named_saveable_objects = []
        if self._saveables_cache is None:
            # No SaveableObject caching. Either we're executing eagerly, or building a
            # static save which is specialized to the current Python state.
            feed_additions = None
        else:
            # If we are caching SaveableObjects, we need to build up a feed_dict with
            # functions computing volatile Python state to be saved with the
            # checkpoint.
            feed_additions = {}
        for checkpoint_id, (trackable, object_proto) in enumerate(
                zip(trackable_objects, object_graph_proto.nodes)):
            assert node_ids[trackable] == checkpoint_id
            object_name = object_names[trackable]
            if object_map is None:
                object_to_save = trackable
            else:
                object_to_save = object_map.get(trackable, trackable)
            if self._saveables_cache is not None:
                cached_attributes = self._saveables_cache.setdefault(
                    object_to_save, {})
            else:
                cached_attributes = None

            for name, saveable_factory in (
                    object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
                attribute = object_proto.attributes.add()
                attribute.name = name
                attribute.checkpoint_key = "%s/%s/%s" % (
                    object_name, _OBJECT_ATTRIBUTES_NAME,
                    _escape_local_name(name))
                if cached_attributes is None:
                    saveables = None
                else:
                    saveables = cached_attributes.get(name, None)
                    if saveables is not None:
                        for saveable in saveables:
                            if attribute.checkpoint_key not in saveable.name:
                                # The checkpoint key for this SaveableObject is different. We
                                # need to re-create it.
                                saveables = None
                                del cached_attributes[name]
                                break
                if saveables is None:
                    if callable(saveable_factory):
                        maybe_saveable = saveable_object_util.create_saveable_object(
                            saveable_factory, attribute.checkpoint_key,
                            call_with_mapped_captures)
                    else:
                        maybe_saveable = saveable_factory
                    if isinstance(maybe_saveable,
                                  saveable_object_lib.SaveableObject):
                        saveables = (maybe_saveable, )
                    else:
                        # Figure out the name-based Saver's name for this variable. If it's
                        # already a SaveableObject we'd just get the checkpoint key back, so
                        # we leave full_name blank.
                        saver_dict = saveable_object_util.op_list_to_dict(
                            [maybe_saveable], convert_variable_to_tensor=False)
                        full_name, = saver_dict.keys()
                        saveables = tuple(
                            saveable_object_util.saveable_objects_for_op(
                                op=maybe_saveable,
                                name=attribute.checkpoint_key))
                        for saveable in saveables:
                            saveable.full_name = full_name
                    for saveable in saveables:
                        if attribute.checkpoint_key not in saveable.name:
                            # Arguments ordered to match the format string:
                            # the saveable's name fills the "name" slot and the
                            # attribute name fills the "attribute" slot (the
                            # previous version had these two swapped).
                            raise AssertionError((
                                "The object %s produced a SaveableObject with name '%s' for "
                                "attribute '%s'. Expected a name containing '%s'."
                            ) % (trackable, saveable.name, name,
                                 attribute.checkpoint_key))
                    if cached_attributes is not None:
                        cached_attributes[name] = saveables

                optional_restore = None
                for saveable in saveables:
                    # optional_restore for the attribute is the AND of all its
                    # saveables' flags.
                    if optional_restore is None:
                        optional_restore = saveable.optional_restore
                    else:
                        optional_restore = optional_restore and saveable.optional_restore

                    if hasattr(saveable, "full_name"):
                        attribute.full_name = saveable.full_name
                    if isinstance(saveable, base.PythonStateSaveable):
                        if feed_additions is None:
                            assert self._saveables_cache is None
                            # If we're not caching saveables, then we're either executing
                            # eagerly or building a static save/restore (e.g. for a
                            # SavedModel). In either case, we should embed the current Python
                            # state in the graph rather than relying on a feed dict.
                            saveable = saveable.freeze()
                        else:
                            saveable_feed_dict = saveable.feed_dict_additions()
                            for new_feed_key in saveable_feed_dict.keys():
                                if new_feed_key in feed_additions:
                                    raise AssertionError((
                                        "The object %s tried to feed a value for the Tensor %s "
                                        "when saving, but another object is already feeding a "
                                        "value.") % (trackable, new_feed_key))
                            feed_additions.update(saveable_feed_dict)
                    named_saveable_objects.append(saveable)
                if optional_restore is None:
                    optional_restore = False
                attribute.optional_restore = optional_restore

        return named_saveable_objects, feed_additions
Example #14
0
    def _add_attributes_to_object_graph_for_saveable_objects(
            self, checkpoint_factory_map, object_graph_proto, node_ids,
            object_map, call_with_mapped_captures):
        """Create SaveableObjects and corresponding SerializedTensor protos.

        Args:
          checkpoint_factory_map: Dict mapping each trackable to a list of
            factory-data entries (each with `name`, `checkpoint_key`, and
            `factory` fields).
          object_graph_proto: The object-graph proto whose node attributes
            are filled in here.
          node_ids: Dict mapping each trackable to its node index in
            `object_graph_proto.nodes`.
          object_map: Optional dict substituting objects to save in place of
            the originals, or None.
          call_with_mapped_captures: Forwarded to
            `saveable_object_util.create_saveable_object`.

        Returns:
          A `(named_saveable_objects, feed_additions)` tuple;
          `feed_additions` is None when SaveableObjects are not cached.
        """
        named_saveable_objects = []
        if self._saveables_cache is None:
            # No SaveableObject caching. Either we're executing eagerly, or building a
            # static save which is specialized to the current Python state.
            feed_additions = None
        else:
            # If we are caching SaveableObjects, we need to build up a feed_dict with
            # functions computing volatile Python state to be saved with the
            # checkpoint.
            feed_additions = {}
        for trackable, factory_data_list in checkpoint_factory_map.items():
            object_proto = object_graph_proto.nodes[node_ids[trackable]]
            if self._saveables_cache is not None:
                object_to_save = _get_mapped_trackable(trackable, object_map)
                cached_attributes = self._saveables_cache.setdefault(
                    object_to_save, {})
            else:
                cached_attributes = None

            for factory_data in factory_data_list:
                attribute = object_proto.attributes.add()
                attribute.name = name = factory_data.name
                attribute.checkpoint_key = key = factory_data.checkpoint_key
                saveable_factory = factory_data.factory

                # See if we can skip saving this checkpoint key.
                saveables = cached_attributes.get(
                    name) if cached_attributes else None
                if saveables is not None:
                    for saveable in saveables:
                        if key not in saveable.name:
                            # The checkpoint key for this SaveableObject is different. We
                            # need to re-create it.
                            saveables = None
                            del cached_attributes[name]
                            break

                if saveables is None:
                    if callable(saveable_factory):
                        maybe_saveable = saveable_object_util.create_saveable_object(
                            saveable_factory, key, call_with_mapped_captures)
                    else:
                        maybe_saveable = saveable_factory
                    if isinstance(maybe_saveable,
                                  saveable_object_lib.SaveableObject):
                        saveables = (maybe_saveable, )
                    else:
                        # Figure out the name-based Saver's name for this variable. If it's
                        # already a SaveableObject we'd just get the checkpoint key back, so
                        # we leave full_name blank.
                        saver_dict = saveable_object_util.op_list_to_dict(
                            [maybe_saveable], convert_variable_to_tensor=False)
                        full_name, = saver_dict.keys()
                        saveables = tuple(
                            saveable_object_util.saveable_objects_for_op(
                                op=maybe_saveable, name=key))
                        for saveable in saveables:
                            saveable.full_name = full_name
                    # Sanity-check that every saveable's name embeds its
                    # checkpoint key before caching.
                    for saveable in saveables:
                        if key not in saveable.name:
                            raise AssertionError(
                                f"The object {trackable} produced a SaveableObject with name "
                                f"'{saveable.name}' for attribute '{name}'. Expected a name"
                                f" containing '{key}'.")
                    if cached_attributes is not None:
                        cached_attributes[name] = saveables

                for saveable in saveables:
                    if hasattr(saveable, "full_name"):
                        attribute.full_name = saveable.full_name
                    if isinstance(saveable, base.PythonStateSaveable):
                        if feed_additions is None:
                            assert self._saveables_cache is None
                            # If we're not caching saveables, then we're either executing
                            # eagerly or building a static save/restore (e.g. for a
                            # SavedModel). In either case, we should embed the current Python
                            # state in the graph rather than relying on a feed dict.
                            saveable = saveable.freeze()
                        else:
                            saveable_feed_dict = saveable.feed_dict_additions()
                            for new_feed_key in saveable_feed_dict.keys():
                                if new_feed_key in feed_additions:
                                    raise AssertionError(
                                        f"The object {trackable} tried to feed a value for the "
                                        f"Tensor {new_feed_key} when saving, but another object "
                                        "is already feeding a value.")
                            feed_additions.update(saveable_feed_dict)
                    named_saveable_objects.append(saveable)

        return named_saveable_objects, feed_additions
Example #15
0
def _add_attributes_to_object_graph_for_saveable_objects(
        checkpoint_factory_map, object_graph_proto, node_ids, object_map,
        call_with_mapped_captures, saveables_cache):
    """Create SaveableObjects and corresponding SerializedTensor protos.

    Args:
      checkpoint_factory_map: Dict mapping each trackable to a list of
        factory-data entries (each with `name`, `checkpoint_key`, and
        `factory` fields).
      object_graph_proto: The object-graph proto whose node attributes are
        filled in here.
      node_ids: Dict mapping each trackable to its node index in
        `object_graph_proto.nodes`.
      object_map: Optional dict substituting objects to save in place of the
        originals, or None.
      call_with_mapped_captures: Forwarded to
        `saveable_object_util.create_saveable_object`.
      saveables_cache: Optional dict caching created SaveableObjects per
        saved object, or None to disable caching.

    Returns:
      A `(named_saveable_objects, feed_additions)` tuple; `feed_additions`
      is None when `saveables_cache` is None.
    """
    named_saveable_objects = []
    if saveables_cache is None:
        # No SaveableObject caching. Either we're executing eagerly, or building a
        # static save which is specialized to the current Python state.
        feed_additions = None
    else:
        # If we are caching SaveableObjects, we need to build up a feed_dict with
        # functions computing volatile Python state to be saved with the
        # checkpoint.
        feed_additions = {}
    for trackable, factory_data_list in checkpoint_factory_map.items():
        object_proto = object_graph_proto.nodes[node_ids[trackable]]
        object_to_save = _get_mapped_trackable(trackable, object_map)
        if saveables_cache is not None:
            cached_attributes = saveables_cache.setdefault(object_to_save, {})
        else:
            cached_attributes = None

        for factory_data in factory_data_list:
            name = factory_data.name
            key = factory_data.checkpoint_key
            saveable_factory = factory_data.factory

            # See if we can skip saving this checkpoint key.
            saveables = cached_attributes.get(
                name) if cached_attributes else None
            if saveables is not None:
                for saveable in saveables:
                    if key not in saveable.name:
                        # The checkpoint key for this SaveableObject is different. We
                        # need to re-create it.
                        saveables = None
                        del cached_attributes[name]
                        break

            if saveables is None:
                if callable(saveable_factory):
                    maybe_saveable = saveable_object_util.create_saveable_object(
                        saveable_factory, key, call_with_mapped_captures)
                else:
                    maybe_saveable = saveable_factory
                if isinstance(maybe_saveable,
                              saveable_object_lib.SaveableObject):
                    saveables = (maybe_saveable, )
                else:
                    saveables = tuple(
                        saveable_object_util.saveable_objects_for_op(
                            op=maybe_saveable, name=key))
                # Sanity-check that every saveable's name embeds its
                # checkpoint key before caching.
                for saveable in saveables:
                    if key not in saveable.name:
                        raise AssertionError(
                            f"The object {trackable} produced a SaveableObject with name "
                            f"'{saveable.name}' for attribute '{name}'. Expected a name"
                            f" containing '{key}'.")
                if cached_attributes is not None:
                    cached_attributes[name] = saveables

            for saveable in saveables:
                if isinstance(saveable, python_state.PythonStateSaveable):
                    if feed_additions is None:
                        assert saveables_cache is None
                        # If we're not caching saveables, then we're either executing
                        # eagerly or building a static save/restore (e.g. for a
                        # SavedModel). In either case, we should embed the current Python
                        # state in the graph rather than relying on a feed dict.
                        saveable = saveable.freeze()
                    else:
                        saveable_feed_dict = saveable.feed_dict_additions()
                        for new_feed_key in saveable_feed_dict.keys():
                            if new_feed_key in feed_additions:
                                raise AssertionError(
                                    f"The object {trackable} tried to feed a value for the "
                                    f"Tensor {new_feed_key} when saving, but another object "
                                    "is already feeding a value.")
                        feed_additions.update(saveable_feed_dict)
                named_saveable_objects.append(saveable)

            # Update the object proto.
            # For updated Trackables that override serialize_to_tensors, add an
            # attribute for each tensor that is serialized.
            # For Trackables that have SaveableObjects or a legacy saveable name,
            # add a single attribute to the proto.
            if (isinstance(saveables[0],
                           saveable_object_util.TrackableSaveable) and
                    saveable_compat.get_saveable_name(object_to_save) is None):
                for local_name, local_key in (
                        saveables[0].get_proto_names_and_checkpoint_keys()):
                    object_proto.attributes.add(
                        name=local_name,
                        checkpoint_key=local_key,
                        full_name=_get_full_name(object_to_save))
            else:
                object_proto.attributes.add(
                    name=name,
                    checkpoint_key=key,
                    full_name=_get_full_name(object_to_save))

    return named_saveable_objects, feed_additions