def test_registration(self):
  """Checks predicate-based saver registration, lookup and validation."""
  registration.register_checkpoint_saver(
      package="Testing",
      name="test_predicate",
      predicate=lambda obj: hasattr(obj, "check_attr"),
      save_fn=lambda: "save",
      restore_fn=lambda: "restore")

  # The predicate keys off `check_attr`, so a fresh Trackable has no saver.
  tracked = base.Trackable()
  self.assertIsNone(registration.get_registered_saver_name(tracked))

  # Once the attribute exists, the saver is found under "<package>.<name>".
  tracked.check_attr = 1
  found_name = registration.get_registered_saver_name(tracked)
  self.assertEqual(found_name, "Testing.test_predicate")

  save_fn = registration.get_save_function(found_name)
  self.assertEqual(save_fn(), "save")
  restore_fn = registration.get_restore_function(found_name)
  self.assertEqual(restore_fn(), "restore")

  # A matching object validates without raising.
  registration.validate_restore_function(tracked, "Testing.test_predicate")

  # An unknown saver name is rejected.
  with self.assertRaisesRegex(ValueError, "saver cannot be found"):
    registration.validate_restore_function(tracked, "Invalid.name")

  # A known saver is rejected for an object its predicate does not accept.
  unmatched = base.Trackable()
  with self.assertRaisesRegex(ValueError, "saver cannot be used"):
    registration.validate_restore_function(unmatched, "Testing.test_predicate")
def get_checkpoint_factories_and_keys(object_names, object_map=None):
  """Builds the saveable-factory map and the registered-saver map.

  Args:
    object_names: a dictionary that maps `Trackable` objects to auto-generated
      string names.
    object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
      The copied objects are generated from `Trackable._map_resources()` which
      copies the object into another graph. Generally only resource objects
      (e.g. Variables, Tables) will be in this map.

  Returns:
    A tuple of (
      Dictionary mapping trackable -> list of _CheckpointFactoryData,
      Dictionary mapping registered saver name -> {object name -> trackable})
  """
  factories_by_trackable = object_identity.ObjectIdentityDictionary()
  objects_by_saver = collections.defaultdict(dict)

  for trackable, object_name in object_names.items():
    # The mapped object only supplies the saving functionality; keys and
    # every other piece of data come from the original `trackable`.
    object_to_save = util.get_mapped_trackable(trackable, object_map)

    saver_name = registration.get_registered_saver_name(object_to_save)
    if saver_name:
      # Store the original trackable (not `object_to_save`): the original is
      # needed for writing the object proto.
      objects_by_saver[saver_name][object_name] = trackable
      continue

    factory_data = []
    factories_by_trackable[trackable] = factory_data
    factories = saveable_object_util.saveable_objects_from_trackable(
        object_to_save)
    for name, saveable_factory in factories.items():
      # Prefer the legacy saveable name when one exists (kept for
      # compatibility during the SaveableObject deprecation).
      key_suffix = saveable_compat.get_saveable_name(object_to_save) or name
      checkpoint_key = trackable_utils.checkpoint_key(object_name, key_suffix)

      # The legacy name also becomes the factory name, but only while
      # checkpoint conversion is disabled.
      if saveable_compat.force_checkpoint_conversion_enabled():
        factory_name = name
      else:
        factory_name = key_suffix

      factory_data.append(
          _CheckpointFactoryData(
              factory=saveable_factory,
              name=factory_name,
              checkpoint_key=checkpoint_key))

  return factories_by_trackable, objects_by_saver
def gather_ops_or_named_saveables(self):
  """Looks up or creates SaveableObjects which don't have cached ops.

  Returns:
    When the trackable exposes saveable factories, whatever
    `_create_serialize_to_tensor_saveable` or
    `_create_saveables_by_attribute_name` returns. Otherwise a 4-tuple of
    (empty list, empty dict, empty list, registered_savers), where
    `registered_savers` maps a registered saver name to
    {checkpoint key -> trackable} and is empty when no registered saver
    applies.
  """
  # pylint:disable=g-import-not-at-top
  # There are circular dependencies between Trackable and SaveableObject,
  # so we must import it here.
  # TODO(b/224069573): Remove this code from Trackable.
  from tensorflow.python.training.saving import saveable_object_util
  # pylint:enable=g-import-not-at-top

  if not self.object_proto.attributes:
    return [], {}, [], {}

  saveable_factories = saveable_object_util.saveable_objects_from_trackable(
      self.trackable)
  if saveable_factories.keys() == {trackable_utils.SERIALIZE_TO_TENSORS_NAME}:
    return self._create_serialize_to_tensor_saveable(saveable_factories)
  elif saveable_factories:
    return self._create_saveables_by_attribute_name(saveable_factories)
  elif self.object_proto.attributes:
    # The checkpoint may have a serialized tensor recorded, but the
    # Trackable appears to have no tensors to serialize/restore. When this
    # happens, it means that the Trackable has migrated to the registered
    # checkpoint functionality (TPUEmbedding is an example of this).
    saver_name = registration.get_registered_saver_name(self.trackable)
    if saver_name:
      # For now, set the Trackable's object name to the first checkpoint
      # key that is stored in checkpoint. If there is a use case that
      # requires the other keys, then we can take another look at this.
      first_key = self.object_proto.attributes[0].checkpoint_key
      registered_savers = {saver_name: {first_key: self.trackable}}
      # Same (list, dict, list, dict) layout as the early return above; the
      # previous `{}, [], [], ...` swapped the first two container types.
      return [], {}, [], registered_savers

  # If no registered savers were found, then it means that one or more
  # serialized tensors were never used.
  for serialized_tensor in self.object_proto.attributes:
    self._checkpoint.unused_attributes.setdefault(
        self._proto_id, []).append(serialized_tensor.name)
  return [], {}, [], {}
def get_checkpoint_factories_and_keys(object_names, object_map=None):
  """Builds the saveable-factory map and the registered-saver map.

  Args:
    object_names: a dictionary that maps `Trackable` objects to auto-generated
      string names.
    object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
      The copied objects are generated from `Trackable._map_resources()` which
      copies the object into another graph. Generally only resource objects
      (e.g. Variables, Tables) will be in this map.

  Returns:
    A tuple of (
      Dictionary mapping trackable -> list of _CheckpointFactoryData,
      Dictionary mapping registered saver name -> {object name -> trackable})
  """
  factories_by_trackable = object_identity.ObjectIdentityDictionary()
  objects_by_saver = collections.defaultdict(dict)

  for trackable, object_name in object_names.items():
    # The mapped object only supplies the saving functionality; keys and
    # every other piece of data come from the original `trackable`.
    object_to_save = _get_mapped_trackable(trackable, object_map)

    saver_name = registration.get_registered_saver_name(object_to_save)
    if saver_name:
      objects_by_saver[saver_name][object_name] = trackable
      continue

    factory_data = []
    factories_by_trackable[trackable] = factory_data
    factories = saveable_object_util.saveable_objects_from_trackable(
        object_to_save)
    for name, saveable_factory in factories.items():
      factory_data.append(
          _CheckpointFactoryData(
              factory=saveable_factory,
              name=name,
              checkpoint_key=trackable_utils.checkpoint_key(
                  object_name, name)))

  return factories_by_trackable, objects_by_saver
def gather_ops_or_named_saveables(self):
  """Looks up or creates SaveableObjects which don't have cached ops.

  Returns:
    A tuple of (
        existing_restore_ops: list,
        named_saveables: dict,
        python_positions: list,
        registered_savers: dict)
  """
  # pylint:disable=g-import-not-at-top
  # There are circular dependencies between Trackable and SaveableObject,
  # so we must import it here.
  # TODO(b/224069573): Remove this code from Trackable.
  from tensorflow.python.training.saving import saveable_object_util
  # pylint:enable=g-import-not-at-top

  saver_in_checkpoint = self.get_registered_saver_name()
  if not self.object_proto.attributes and not saver_in_checkpoint:
    # Nothing is recorded for this object, so there is nothing to gather.
    return [], {}, [], {}

  restore_ops = []
  saveables = {}
  positions = []
  savers = collections.defaultdict(dict)

  factories = saveable_object_util.saveable_objects_from_trackable(
      self.trackable)
  saver_for_trackable = registration.get_registered_saver_name(self.trackable)

  if saver_in_checkpoint:
    # When `skip_restore` is set, restoration of this Trackable is skipped.
    # That only happens if the registered saver has enabled `option_restore`;
    # otherwise an error would have been raised at
    # `self.get_registered_saver_name()`.
    if not self.skip_restore:
      object_name = self.object_proto.registered_saver.object_name
      savers[saver_in_checkpoint][object_name] = self.trackable
  elif saver_for_trackable:
    # The checkpoint has a recorded serialized tensor but no registered
    # saver, while the Trackable loading it has migrated to the registered
    # checkpoint functionality (TPUEmbedding is an example of this).
    # The first checkpoint key stored in the checkpoint is used as the
    # Trackable's object name. If a use case requires the other keys, this
    # can be revisited.
    first_key = self.object_proto.attributes[0].checkpoint_key
    savers[saver_for_trackable] = {first_key: self.trackable}
  elif isinstance(self.trackable, python_state.PythonState):
    positions.append(self)
  elif factories.keys() == {trackable_utils.SERIALIZE_TO_TENSORS_NAME}:
    restore_ops, saveables = self._create_serialize_to_tensor_saveable(
        factories)
  elif factories:
    restore_ops, saveables = self._create_saveables_by_attribute_name(
        factories)
  else:
    # No registered saver and no factories: the serialized tensors recorded
    # for this object were never used, so note each one as unused.
    for unused_tensor in self.object_proto.attributes:
      self._checkpoint.unused_attributes.setdefault(
          self._proto_id, []).append(unused_tensor.name)

  return restore_ops, saveables, positions, savers