def test_register_twice(self):
  """Registering the same identifier a second time must raise."""
  registration_args = (
      "foo",
      lambda value: isinstance(value, int),
      [
          revived_types.VersionedTypeRegistration(
              object_factory=lambda _: 1,
              version=1,
              min_producer_version=1,
              min_consumer_version=1),
      ],
  )
  # First registration succeeds; the identical second one must be rejected.
  revived_types.register_revived_type(*registration_args)
  with self.assertRaisesRegex(AssertionError, "Duplicate registrations"):
    revived_types.register_revived_type(*registration_args)
def _list_functions_for_serialization(self, unused_serialization_cache):
  """Returns the entries of this mapping that are serializable functions."""
  # Only tf.function-like values are exported as functions; other items go
  # through the standard trackable dependency machinery.
  return {
      key: value for key, value in self.items()
      if isinstance(value, (def_function.Function, defun.ConcreteFunction))
  }


# Lets tf.saved_model.load revive a saved _SignatureMap: an empty map is
# created and repopulated one signature at a time via _add_signature.
revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[
        revived_types.VersionedTypeRegistration(
            # Standard dependencies are enough to reconstruct the trackable
            # items in dictionaries, so we don't need to save any extra information.
            object_factory=lambda proto: _SignatureMap(),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            setter=_SignatureMap._add_signature  # pylint: disable=protected-access
        )
    ])


def create_signature_map(signatures):
  """Creates an object containing `signatures`."""
  signature_map = _SignatureMap()
  # NOTE(review): the body below is truncated in this chunk; it continues
  # past the visible source.
  for name, func in signatures.items():
    # This true of any signature that came from canonicalize_signatures. Here as
    # a sanity check on saving; crashing on load (e.g. in _add_signature) would
    # be more problematic in case future export changes violated these
# Tail of CustomTestClass.__init__ (the def line is above this chunk):
# records which registered version produced this instance.
    self.version = version


# Registers four revivable versions of CustomTestClass so version-matching
# between producer and consumer can be exercised; version 4 additionally
# specifies bad_consumers=[3].
revived_types.register_revived_type(
    "test_type",
    lambda obj: isinstance(obj, CustomTestClass),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(1),
            version=1,
            min_producer_version=1,
            min_consumer_version=1),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(2),
            version=2,
            min_producer_version=2,
            min_consumer_version=1),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(3),
            version=3,
            min_producer_version=3,
            min_consumer_version=2),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(4),
            version=4,
            min_producer_version=4,
            min_consumer_version=2,
            bad_consumers=[3]),
    ])


# NOTE(review): the class body continues past this chunk.
class RegistrationMatchingTest(test.TestCase):
# Tail of a dict comprehension from the enclosing method (its opening brace
# and def line are above this chunk): keeps only function-valued entries.
      key: value for key, value in self.items()
      if _is_function(value)
  }


def _is_function(x):
  """Returns True if `x` is a serializable tf.function-like object."""
  return isinstance(x, (def_function.Function, defun.ConcreteFunction))


# Lets tf.saved_model.load revive a saved _DictWrapper: an empty wrapper is
# created and repopulated item-by-item via operator.setitem.
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            # Standard dependencies are enough to reconstruct the trackable
            # items in dictionaries, so we don't need to save any extra information.
            object_factory=lambda proto: _DictWrapper({}),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            setter=operator.setitem)
    ])


def _set_list_item(list_object, index_string, value):
  """Sets `list_object[int(index_string)] = value`, padding with None.

  The list is grown with None entries when the target index is past the
  current end, so restored items may arrive in any order.
  """
  item_index = int(index_string)
  if len(list_object) <= item_index:
    list_object.extend([None] * (1 + item_index - len(list_object)))
  list_object[item_index] = value
# Tail of the Snapshot class body (class header and the `variables`
# property's def line are above this chunk).
  return self._variables

  @property
  def trainable_variables(self):
    # Exposes the trainable variables captured by the snapshot.
    return self._trainable_variables


# Registers the Snapshot object above such that when it is restored by
# tf.saved_model.load it will be restored as a Snapshot. This is important
# because it allows us to expose the __call__, and *_variables properties.
revived_types.register_revived_type(
    'acme_snapshot',
    lambda obj: isinstance(obj, Snapshot),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda proto: Snapshot(),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            setter=setattr,
        )
    ])


def make_snapshot(module: snt.Module):
  """Create a thin wrapper around a module to make it snapshottable."""
  # Get the input signature as long as it has been created.
  input_signature = _get_input_signature(module)
  if input_signature is None:
    # NOTE(review): this raise statement is truncated in this chunk; the
    # message continues past the visible source.
    raise ValueError(
        ('module instance "{}" has no input_signature attribute, '
         'which is required for snapshotting; run '
  optimizer object itself (e.g. through `apply_gradients`).
  """

  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.
  def __init__(self):
    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
    # NOTE(review): presumably marks hypers as created so the base class does
    # not rebuild them — values are set via _set_hyper below; confirm.
    self._hypers_created = True

  def get_config(self):
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimzers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")


# Revives any saved OptimizerV2 as a RestoredOptimizer on load; saved
# hyperparameters are applied through _set_hyper.
revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda proto: RestoredOptimizer(),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
        )
    ])
# Tail of an embedding-lookup wrapper's signature (the def line and leading
# parameters are above this chunk).
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  # A single-element list is unwrapped; `params` is then assumed to expose a
  # `.variables` collection (e.g. a ShardedVariable) — TODO confirm at caller.
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)


def _raise_when_load(_):
  """Object factory that refuses to revive a saved ShardedVariable."""
  # We don't have serialization and deserialization mechanisms for
  # `ShardedVariable` in 2.x style save/load yet.
  raise ValueError(
      'Loading a saved_model containing ShardedVariable via '
      '`tf.saved_model.load` is not supported. If the model is built using '
      'Keras, please use `tf.keras.models.load_model` instead.')


# Any attempt to revive a saved ShardedVariable routes through
# _raise_when_load, producing a clear error instead of a broken object.
revived_types.register_revived_type(
    '_tf_distribute_sharded_variable',
    lambda obj: isinstance(obj, ShardedVariable),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=_raise_when_load,
            version=0,
            min_producer_version=0,
            min_consumer_version=0)
    ])
# Tail of a method (its def line is above this chunk): returns only the
# function-valued entries of this mapping.
  return {
      key: value for key, value in self.items()
      if _is_function(value)
  }


def _is_function(x):
  """Returns True if `x` is a serializable tf.function-like object."""
  return isinstance(x, (def_function.Function, defun.ConcreteFunction))


# Lets tf.saved_model.load revive a saved _DictWrapper: an empty wrapper is
# created and repopulated item-by-item via operator.setitem.
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the trackable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _DictWrapper({}),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=operator.setitem)])


def _set_list_item(list_object, index_string, value):
  """Sets `list_object[int(index_string)] = value`, padding with None.

  The list is grown with None entries when the target index is past the
  current end, so restored items may arrive in any order.
  """
  item_index = int(index_string)
  if len(list_object) <= item_index:
    list_object.extend([None] * (1 + item_index - len(list_object)))
  list_object[item_index] = value


# NOTE(review): start of another registration; its arguments continue past
# this chunk.
revived_types.register_revived_type(
# so we'll throw an exception on save. self._non_append_mutation = True del self._storage[key] self._update_snapshot() def __repr__(self): return "DictWrapper(%s)" % (repr(self._storage), ) def __hash__(self): raise TypeError("unhashable type: 'DictWrapper'") def __eq__(self, other): return self._storage == getattr(other, "_storage", other) def update(self, *args, **kwargs): for key, value in dict(*args, **kwargs).items(): self[key] = value revived_types.register_revived_type( "checkpointable_dict_wrapper", lambda obj: isinstance(obj, _DictWrapper), versions=[ revived_types.VersionedTypeRegistration( object_factory=lambda _: _DictWrapper({}), version=1, min_producer_version=1, min_consumer_version=1, setter=operator.setitem) ])
# Methods below belong to _SignatureMap (its class header is above this
# chunk).
  def __repr__(self):
    return "_SignatureMap({})".format(self._signatures)

  def _list_functions_for_serialization(self):
    # Only tf.function-like values are exported as functions.
    return {
        key: value for key, value in self.items()
        if isinstance(value, (def_function.Function, defun.ConcreteFunction))
    }


# Lets tf.saved_model.load revive a saved _SignatureMap: an empty map is
# created and repopulated one signature at a time via _add_signature.
revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the trackable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _SignatureMap(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_SignatureMap._add_signature  # pylint: disable=protected-access
    )])


def create_signature_map(signatures):
  """Creates an object containing `signatures`."""
  signature_map = _SignatureMap()
  # NOTE(review): the body below is truncated in this chunk; it continues
  # past the visible source.
  for name, func in signatures.items():
    # This true of any signature that came from canonicalize_signatures. Here as
    # a sanity check on saving; crashing on load (e.g. in _add_signature) would
    # be more problematic in case future export changes violated these
    # assertions.
class CustomTestClass(tracking.AutoTrackable):
  """Trackable test type carrying a version tag for registration tests."""

  def __init__(self, version):
    # Records which registered version produced this instance.
    self.version = version


# Registers four revivable versions of CustomTestClass so version-matching
# between producer and consumer can be exercised; version 4 additionally
# specifies bad_consumers=[3].
revived_types.register_revived_type(
    "test_type",
    lambda obj: isinstance(obj, CustomTestClass),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(1),
            version=1,
            min_producer_version=1,
            min_consumer_version=1),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(2),
            version=2,
            min_producer_version=2,
            min_consumer_version=1),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(3),
            version=3,
            min_producer_version=3,
            min_consumer_version=2),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(4),
            version=4,
            min_producer_version=4,
            min_consumer_version=2,
            bad_consumers=[3]),
    ]
)


# NOTE(review): the class may continue past this chunk.
class RegistrationMatchingTest(test.TestCase):

  def test_save_typecheck(self):
    # An unregistered trackable type should not serialize to any
    # revived-type proto.
    self.assertIs(revived_types.serialize(tracking.AutoTrackable()), None)
# Tail of _DictWrapper.__hash__ (its def line is above this chunk): wrapped
# dicts are mutable, so hashing is explicitly disallowed.
    raise TypeError("unhashable type: 'DictWrapper'")

  def __eq__(self, other):
    # Compares underlying storage; falls back to `other` itself when it has
    # no _storage attribute (e.g. a plain dict).
    return self._storage == getattr(other, "_storage", other)

  def update(self, *args, **kwargs):
    # Same signature as dict.update; each assignment goes through
    # __setitem__.
    for key, value in dict(*args, **kwargs).items():
      self[key] = value


# Revives a saved _DictWrapper as an empty wrapper refilled item-by-item.
# This registration also supplies getter and attribute_extractor callbacks.
revived_types.register_revived_type(
    "checkpointable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the checkpointable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _DictWrapper({}),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=operator.setitem,
        getter=_DictWrapper.get,
        attribute_extractor=lambda obj: obj.keys())])


def _set_list_item(list_object, index_string, value):
  """Sets `list_object[int(index_string)] = value`, padding with None.

  The list is grown with None entries when the target index is past the
  current end, so restored items may arrive in any order.
  """
  item_index = int(index_string)
  if len(list_object) <= item_index:
    list_object.extend([None] * (1 + item_index - len(list_object)))
  list_object[item_index] = value
  Holds slot variables and hyperparameters when an optimizer is restored
  from a SavedModel. These variables may be referenced in functions along
  with ops created by the original optimizer, but currently we do not
  support using the optimizer object itself (e.g. through
  `apply_gradients`).
  """

  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.
  def __init__(self):
    super(_RestoredOptimizer, self).__init__("_RestoredOptimizer")
    # NOTE(review): presumably marks hypers as created so the base class does
    # not rebuild them — values are set via _set_hyper below; confirm.
    self._hypers_created = True

  def get_config(self):
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimzers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")


# Revives any saved OptimizerV2 as a _RestoredOptimizer on load; saved
# hyperparameters are applied through _set_hyper.
revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: _RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])