def test_register_twice(self):
  """Re-registering an already-used identifier must raise an AssertionError.

  The same identifier, predicate and version list are registered twice; the
  second call is expected to fail with a "Duplicate registrations" message.
  """

  def _is_int(value):
    return isinstance(value, int)

  def _make_one(unused_proto):
    return 1

  registrations = [
      revived_types.VersionedTypeRegistration(
          object_factory=_make_one,
          version=1,
          min_producer_version=1,
          min_consumer_version=1,
      ),
  ]
  type_name = "foo"

  # First registration of `type_name`; any exception here errors the test.
  revived_types.register_revived_type(type_name, _is_int, registrations)

  # Registering the identical identifier again must be rejected.
  with self.assertRaisesRegex(AssertionError, "Duplicate registrations"):
    revived_types.register_revived_type(type_name, _is_int, registrations)
key: value for key, value in self.items() if isinstance(value, (def_function.Function, defun.ConcreteFunction)) } revived_types.register_revived_type( "signature_map", lambda obj: isinstance(obj, _SignatureMap), versions=[ revived_types.VersionedTypeRegistration( # Standard dependencies are enough to reconstruct the trackable # items in dictionaries, so we don't need to save any extra information. object_factory=lambda proto: _SignatureMap(), version=1, min_producer_version=1, min_consumer_version=1, setter=_SignatureMap._add_signature # pylint: disable=protected-access ) ]) def create_signature_map(signatures): """Creates an object containing `signatures`.""" signature_map = _SignatureMap() for name, func in signatures.items(): # This true of any signature that came from canonicalize_signatures. Here as # a sanity check on saving; crashing on load (e.g. in _add_signature) would # be more problematic in case future export changes violated these # assertions.
from tensorflow.python.saved_model import revived_types from tensorflow.python.trackable import autotrackable class CustomTestClass(autotrackable.AutoTrackable): def __init__(self, version): self.version = version revived_types.register_revived_type( "test_type", lambda obj: isinstance(obj, CustomTestClass), versions=[ revived_types.VersionedTypeRegistration( object_factory=lambda _: CustomTestClass(1), version=1, min_producer_version=1, min_consumer_version=1), revived_types.VersionedTypeRegistration( object_factory=lambda _: CustomTestClass(2), version=2, min_producer_version=2, min_consumer_version=1), revived_types.VersionedTypeRegistration( object_factory=lambda _: CustomTestClass(3), version=3, min_producer_version=3, min_consumer_version=2), revived_types.VersionedTypeRegistration( object_factory=lambda _: CustomTestClass(4), version=4,
for key, value in self.items() if _is_function(value) } def _is_function(x): return isinstance(x, (def_function.Function, defun.ConcreteFunction)) revived_types.register_revived_type( "trackable_dict_wrapper", lambda obj: isinstance(obj, _DictWrapper), versions=[ revived_types.VersionedTypeRegistration( # Standard dependencies are enough to reconstruct the trackable # items in dictionaries, so we don't need to save any extra information. object_factory=lambda proto: _DictWrapper({}), version=1, min_producer_version=1, min_consumer_version=1, setter=operator.setitem) ]) def _set_list_item(list_object, index_string, value): item_index = int(index_string) if len(list_object) <= item_index: list_object.extend([None] * (1 + item_index - len(list_object))) list_object[item_index] = value revived_types.register_revived_type( "trackable_list_wrapper",
@property def trainable_variables(self): return self._trainable_variables # Registers the Snapshot object above such that when it is restored by # tf.saved_model.load it will be restored as a Snapshot. This is important # because it allows us to expose the __call__, and *_variables properties. revived_types.register_revived_type( 'acme_snapshot', lambda obj: isinstance(obj, Snapshot), versions=[ revived_types.VersionedTypeRegistration( object_factory=lambda proto: Snapshot(), version=1, min_producer_version=1, min_consumer_version=1, setter=setattr, ) ]) def make_snapshot(module: snt.Module): """Create a thin wrapper around a module to make it snapshottable.""" # Get the input signature as long as it has been created. input_signature = _get_input_signature(module) if input_signature is None: raise ValueError( ('module instance "{}" has no input_signature attribute, ' 'which is required for snapshotting; run ' 'create_variables to add this annotation.').format(module.name))
optimizer object iself (e.g. through `apply_gradients`). """ # TODO(allenl): Make the restored optimizer functional by tracing its apply # methods. def __init__(self): super(RestoredOptimizer, self).__init__("RestoredOptimizer") self._hypers_created = True def get_config(self): # TODO(allenl): Save and restore the Optimizer's config raise NotImplementedError( "Restoring functional Optimzers from SavedModels is not currently " "supported. Please file a feature request if this limitation bothers " "you.") revived_types.register_revived_type( "optimizer", lambda obj: isinstance(obj, OptimizerV2), versions=[ revived_types.VersionedTypeRegistration( object_factory=lambda proto: RestoredOptimizer(), version=1, min_producer_version=1, min_consumer_version=1, setter=RestoredOptimizer._set_hyper # pylint: disable=protected-access ) ])
partition_strategy='mod', name=None, validate_indices=True, max_norm=None): if isinstance(params, list): params = params[0] return embedding_ops.embedding_lookup(params.variables, ids, partition_strategy, name, validate_indices, max_norm) def _raise_when_load(_): # We don't have serialization and deserialization mechanisms for # `ShardedVariable` in 2.x style save/load yet. raise ValueError( 'Loading a saved_model containing ShardedVariable via ' '`tf.saved_model.load` is not supported. If the model is built using ' 'Keras, please use `tf.keras.models.load_model` instead.') revived_types.register_revived_type( '_tf_distribute_sharded_variable', lambda obj: isinstance(obj, ShardedVariable), versions=[ revived_types.VersionedTypeRegistration( object_factory=_raise_when_load, version=0, min_producer_version=0, min_consumer_version=0) ])
# so we'll throw an exception on save. self._non_append_mutation = True del self._storage[key] self._update_snapshot() def __repr__(self): return "DictWrapper(%s)" % (repr(self._storage), ) def __hash__(self): raise TypeError("unhashable type: 'DictWrapper'") def __eq__(self, other): return self._storage == getattr(other, "_storage", other) def update(self, *args, **kwargs): for key, value in dict(*args, **kwargs).items(): self[key] = value revived_types.register_revived_type( "checkpointable_dict_wrapper", lambda obj: isinstance(obj, _DictWrapper), versions=[ revived_types.VersionedTypeRegistration( object_factory=lambda _: _DictWrapper({}), version=1, min_producer_version=1, min_consumer_version=1, setter=operator.setitem) ])
def __eq__(self, other): return self._storage == getattr(other, "_storage", other) def update(self, *args, **kwargs): for key, value in dict(*args, **kwargs).items(): self[key] = value revived_types.register_revived_type( "checkpointable_dict_wrapper", lambda obj: isinstance(obj, _DictWrapper), versions=[revived_types.VersionedTypeRegistration( # Standard dependencies are enough to reconstruct the checkpointable # items in dictionaries, so we don't need to save any extra information. object_factory=lambda proto: _DictWrapper({}), version=1, min_producer_version=1, min_consumer_version=1, setter=operator.setitem, getter=_DictWrapper.get, attribute_extractor=lambda obj: obj.keys())]) def _set_list_item(list_object, index_string, value): item_index = int(index_string) if len(list_object) <= item_index: list_object.extend([None] * (1 + item_index - len(list_object))) list_object[item_index] = value def _list_getter(obj, item, default=None): index = int(item)