def test_registration(self):
        """Checks predicate-based saver registration, lookup and validation."""
        registration.register_checkpoint_saver(
            package="Testing",
            name="test_predicate",
            predicate=lambda obj: hasattr(obj, "check_attr"),
            save_fn=lambda: "save",
            restore_fn=lambda: "restore")
        expected_name = "Testing.test_predicate"

        # A fresh Trackable lacks `check_attr`, so the predicate rejects it.
        candidate = base.Trackable()
        self.assertIsNone(registration.get_registered_saver_name(candidate))

        # Once the attribute exists, the predicate matches the registered saver.
        candidate.check_attr = 1
        found_name = registration.get_registered_saver_name(candidate)
        self.assertEqual(found_name, expected_name)

        save_fn = registration.get_save_function(found_name)
        self.assertEqual(save_fn(), "save")
        restore_fn = registration.get_restore_function(found_name)
        self.assertEqual(restore_fn(), "restore")

        # A matching object validates against its own saver name ...
        registration.validate_restore_function(candidate, expected_name)
        # ... an unknown saver name is rejected ...
        with self.assertRaisesRegex(ValueError, "saver cannot be found"):
            registration.validate_restore_function(candidate, "Invalid.name")
        # ... and an object that fails the predicate cannot use the saver.
        unmatched = base.Trackable()
        with self.assertRaisesRegex(ValueError, "saver cannot be used"):
            registration.validate_restore_function(unmatched, expected_name)
예제 #2
0
def get_checkpoint_factories_and_keys(object_names, object_map=None):
    """Gets a map of saveable factories and corresponding checkpoint keys.

    Args:
      object_names: a dictionary that maps `Trackable` objects to auto-generated
        string names.
      object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
        The copied objects are generated from `Trackable._map_resources()` which
        copies the object into another graph. Generally only resource objects
        (e.g. Variables, Tables) will be in this map.

    Returns:
      A tuple of (
        Dictionary mapping trackable -> list of _CheckpointFactoryData,
        Dictionary mapping registered saver name -> {object name -> trackable})
      A trackable handled by a registered saver appears only in the second
      dictionary; every other trackable gets an entry (possibly an empty list)
      in the first.
    """
    checkpoint_factory_map = object_identity.ObjectIdentityDictionary()
    unmapped_registered_savers = collections.defaultdict(dict)
    for trackable, object_name in object_names.items():
        # object_to_save is only used to retrieve the saving functionality. For
        # keys and other data, use the original `trackable`.
        object_to_save = util.get_mapped_trackable(trackable, object_map)

        saver_name = registration.get_registered_saver_name(object_to_save)
        if saver_name:
            # Add the original trackable instead of `object_to_save` to the
            # returned dict because the original is needed for writing the
            # object proto.
            unmapped_registered_savers[saver_name][object_name] = trackable
            continue

        checkpoint_factory_map[trackable] = []
        saveable_factories = saveable_object_util.saveable_objects_from_trackable(
            object_to_save)
        # The legacy saveable name (kept for compatibility during the
        # SaveableObject deprecation) depends only on `object_to_save`, so it
        # is looked up once per trackable instead of once per saveable.
        legacy_saveable_name = saveable_compat.get_saveable_name(object_to_save)

        for name, saveable_factory in saveable_factories.items():
            key_suffix = legacy_saveable_name or name
            checkpoint_key = trackable_utils.checkpoint_key(
                object_name, key_suffix)

            if not saveable_compat.force_checkpoint_conversion_enabled():
                # Only when checkpoint conversion is disabled: name the factory
                # with the legacy saveable name if there is one (`key_suffix`
                # falls back to `name` otherwise).
                name = key_suffix

            checkpoint_factory_map[trackable].append(
                _CheckpointFactoryData(factory=saveable_factory,
                                       name=name,
                                       checkpoint_key=checkpoint_key))
    return checkpoint_factory_map, unmapped_registered_savers
예제 #3
0
    def gather_ops_or_named_saveables(self):
        """Looks up or creates SaveableObjects which don't have cached ops.

        Returns:
          A 4-tuple. The empty result is `([], {}, [], {})` (a list, a dict, a
          list, a dict); the last element maps a registered saver name to
          `{object name: trackable}`. When saveable factories exist, the result
          of `_create_serialize_to_tensor_saveable` or
          `_create_saveables_by_attribute_name` is returned as-is
          (NOTE(review): assumed to follow the same 4-tuple layout).
        """
        # pylint:disable=g-import-not-at-top
        # There are circular dependencies between Trackable and SaveableObject,
        # so we must import it here.
        # TODO(b/224069573): Remove this code from Trackable.
        from tensorflow.python.training.saving import saveable_object_util
        # pylint:enable=g-import-not-at-top

        if not self.object_proto.attributes:
            return [], {}, [], {}

        saveable_factories = saveable_object_util.saveable_objects_from_trackable(
            self.trackable)
        if saveable_factories.keys() == {
                trackable_utils.SERIALIZE_TO_TENSORS_NAME
        }:
            return self._create_serialize_to_tensor_saveable(
                saveable_factories)
        if saveable_factories:
            return self._create_saveables_by_attribute_name(saveable_factories)

        # `attributes` is non-empty here (the empty case returned above). The
        # checkpoint has a serialized tensor recorded, but the Trackable appears
        # to have no tensors to serialize/restore. When this happens, it means
        # that the Trackable has migrated to the registered checkpoint
        # functionality (TPUEmbedding is an example of this).
        saver_name = registration.get_registered_saver_name(self.trackable)
        if saver_name:
            registered_savers = {
                saver_name: {
                    # For now, set the Trackable's object name to the first
                    # checkpoint key that is stored in checkpoint. If there is a
                    # use case that requires the other keys, then we can take
                    # another look at this.
                    self.object_proto.attributes[0].checkpoint_key:
                    self.trackable
                }
            }
            # Same (list, dict, list, dict) layout as the early return above.
            return [], {}, [], registered_savers

        # If no registered savers were found, then it means that one or more
        # serialized tensors were never used.
        for serialized_tensor in self.object_proto.attributes:
            self._checkpoint.unused_attributes.setdefault(
                self._proto_id, []).append(serialized_tensor.name)
        return [], {}, [], {}
예제 #4
0
def get_checkpoint_factories_and_keys(object_names, object_map=None):
    """Splits trackables into saveable-factory entries and registered savers.

    Args:
      object_names: a dictionary that maps `Trackable` objects to auto-generated
        string names.
      object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
        The copied objects are generated from `Trackable._map_resources()` which
        copies the object into another graph. Generally only resource objects
        (e.g. Variables, Tables) will be in this map.

    Returns:
      A tuple of (
        Dictionary mapping trackable -> list of _CheckpointFactoryData,
        Dictionary mapping registered saver name -> {object name -> trackable})
    """
    factories_by_trackable = object_identity.ObjectIdentityDictionary()
    savers_by_name = collections.defaultdict(dict)
    for original, obj_name in object_names.items():
        # Only the (possibly mapped) copy supplies the saving behavior; the
        # original trackable is what gets keyed in the results.
        source = _get_mapped_trackable(original, object_map)

        registered_name = registration.get_registered_saver_name(source)
        if registered_name:
            savers_by_name[registered_name][obj_name] = original
            continue

        attr_factories = saveable_object_util.saveable_objects_from_trackable(
            source)
        factories_by_trackable[original] = [
            _CheckpointFactoryData(
                factory=factory,
                name=attr_name,
                checkpoint_key=trackable_utils.checkpoint_key(obj_name,
                                                              attr_name))
            for attr_name, factory in attr_factories.items()
        ]
    return factories_by_trackable, savers_by_name
예제 #5
0
    def gather_ops_or_named_saveables(self):
        """Collects restore ops, named saveables, Python positions and savers.

        SaveableObjects are looked up or created only for attributes that do
        not already have cached ops.

        Returns:
          A tuple of (
              existing_restore_ops: list,
              named_saveables: dict,
              python_positions: list,
              registered_savers: dict)
        """
        # pylint:disable=g-import-not-at-top
        # There are circular dependencies between Trackable and SaveableObject,
        # so we must import it here.
        # TODO(b/224069573): Remove this code from Trackable.
        from tensorflow.python.training.saving import saveable_object_util
        # pylint:enable=g-import-not-at-top

        checkpointed_saver = self.get_registered_saver_name()
        # Nothing was recorded for this object: no attributes and no saver.
        if not self.object_proto.attributes and not checkpointed_saver:
            return [], {}, [], {}

        restore_ops = []
        saveables_by_name = {}
        python_state_positions = []
        savers = collections.defaultdict(dict)

        factories = saveable_object_util.saveable_objects_from_trackable(
            self.trackable)
        current_saver = registration.get_registered_saver_name(self.trackable)

        if checkpointed_saver:
            # When `skip_restore` is set the Trackable is simply not restored;
            # per the original note this only happens if the registered saver
            # enabled `option_restore` (otherwise
            # `self.get_registered_saver_name()` would have raised).
            if not self.skip_restore:
                recorded_object_name = (
                    self.object_proto.registered_saver.object_name)
                savers[checkpointed_saver][recorded_object_name] = (
                    self.trackable)
        elif current_saver:
            # The checkpoint holds serialized tensors rather than a registered
            # saver, but the loading Trackable now uses registered
            # checkpointing (e.g. TPUEmbedding). The first stored checkpoint
            # key is used as the object name; revisit if other keys are needed.
            first_key = self.object_proto.attributes[0].checkpoint_key
            savers[current_saver] = {first_key: self.trackable}
        elif isinstance(self.trackable, python_state.PythonState):
            python_state_positions.append(self)
        elif factories.keys() == {trackable_utils.SERIALIZE_TO_TENSORS_NAME}:
            restore_ops, saveables_by_name = (
                self._create_serialize_to_tensor_saveable(factories))
        elif factories:
            restore_ops, saveables_by_name = (
                self._create_saveables_by_attribute_name(factories))
        else:
            # No saver and no factories: every recorded serialized tensor is
            # reported as unused.
            unused = self._checkpoint.unused_attributes
            for attribute in self.object_proto.attributes:
                unused.setdefault(self._proto_id, []).append(attribute.name)
        return (restore_ops, saveables_by_name, python_state_positions,
                savers)