Example #1
0
 def test_register_twice(self):
   """Checks that registering the same identifier twice raises an error."""
   def accepts_ints(x):
     return isinstance(x, int)

   registrations = [
       revived_types.VersionedTypeRegistration(
           object_factory=lambda _: 1,
           version=1,
           min_producer_version=1,
           min_consumer_version=1),
   ]
   # The exact same arguments are used for both calls; only the second call
   # is expected to fail.
   registration_args = ("foo", accepts_ints, registrations)
   revived_types.register_revived_type(*registration_args)
   with self.assertRaisesRegex(AssertionError, "Duplicate registrations"):
     revived_types.register_revived_type(*registration_args)
            key: value
            for key, value in self.items()
            if isinstance(value, (def_function.Function,
                                  defun.ConcreteFunction))
        }


# Register `_SignatureMap` with `revived_types` under the identifier
# "signature_map". The predicate selects `_SignatureMap` instances, and the
# single version-1 registration below rebuilds them from an empty
# `_SignatureMap()`.
revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[
        revived_types.VersionedTypeRegistration(
            # Standard dependencies are enough to reconstruct the trackable
            # items in dictionaries, so we don't need to save any extra information.
            # The factory therefore ignores its `proto` argument.
            object_factory=lambda proto: _SignatureMap(),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            # NOTE(review): presumably the setter is how restored entries get
            # attached to the new object -- confirm against `revived_types`.
            setter=_SignatureMap._add_signature  # pylint: disable=protected-access
        )
    ])


def create_signature_map(signatures):
    """Creates an object containing `signatures`."""
    signature_map = _SignatureMap()
    for name, func in signatures.items():
        # This is true of any signature that came from canonicalize_signatures. Here as
        # a sanity check on saving; crashing on load (e.g. in _add_signature) would
        # be more problematic in case future export changes violated these
        # assertions.
Example #3
0
from tensorflow.python.saved_model import revived_types
from tensorflow.python.trackable import autotrackable


class CustomTestClass(autotrackable.AutoTrackable):
    """`AutoTrackable` subclass that records the `version` it was built with."""
    def __init__(self, version):
        # Keep the constructor argument so it can be read back from the instance.
        self.version = version


revived_types.register_revived_type(
    "test_type",
    lambda obj: isinstance(obj, CustomTestClass),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(1),
            version=1,
            min_producer_version=1,
            min_consumer_version=1),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(2),
            version=2,
            min_producer_version=2,
            min_consumer_version=1),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(3),
            version=3,
            min_producer_version=3,
            min_consumer_version=2),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(4),
            version=4,
Example #4
0
            for key, value in self.items() if _is_function(value)
        }


def _is_function(x):
    """Returns True if `x` is a `Function` or a `ConcreteFunction`."""
    function_types = (def_function.Function, defun.ConcreteFunction)
    return isinstance(x, function_types)


# Register `_DictWrapper` with `revived_types` under the identifier
# "trackable_dict_wrapper". The predicate selects `_DictWrapper` instances and
# the single version-1 registration rebuilds them starting from an empty dict.
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            # Standard dependencies are enough to reconstruct the trackable
            # items in dictionaries, so we don't need to save any extra information.
            # The factory therefore ignores its `proto` argument.
            object_factory=lambda proto: _DictWrapper({}),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            # `operator.setitem(obj, key, value)` performs `obj[key] = value`.
            setter=operator.setitem)
    ])


def _set_list_item(list_object, index_string, value):
    """Stores `value` at the index named by `index_string`, growing the list.

    Args:
      list_object: Mutable sequence supporting `len`, `extend` and item
        assignment.
      index_string: The target index as a string; converted with `int`.
      value: The object to store at that index.
    """
    position = int(index_string)
    missing = position + 1 - len(list_object)
    if missing > 0:
        # Pad with `None` so that `position` becomes a valid index.
        list_object.extend([None] * missing)
    list_object[position] = value


revived_types.register_revived_type(
    "trackable_list_wrapper",
Example #5
0
    @property
    def trainable_variables(self):
        """Returns the object stored in `self._trainable_variables`."""
        return self._trainable_variables


# Registers the Snapshot object above such that when it is restored by
# tf.saved_model.load it will be restored as a Snapshot. This is important
# because it allows us to expose the __call__, and *_variables properties.
# The registration uses the identifier 'acme_snapshot' and a predicate that
# selects `Snapshot` instances.
revived_types.register_revived_type(
    'acme_snapshot',
    lambda obj: isinstance(obj, Snapshot),
    versions=[
        revived_types.VersionedTypeRegistration(
            # The factory ignores `proto` and builds a `Snapshot` with no
            # arguments.
            object_factory=lambda proto: Snapshot(),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            # NOTE(review): presumably restored children are attached as plain
            # attributes via `setattr(obj, name, value)` -- confirm against
            # `revived_types`.
            setter=setattr,
        )
    ])


def make_snapshot(module: snt.Module):
    """Create a thin wrapper around a module to make it snapshottable."""
    # Get the input signature as long as it has been created.
    input_signature = _get_input_signature(module)
    if input_signature is None:
        raise ValueError(
            ('module instance "{}" has no input_signature attribute, '
             'which is required for snapshotting; run '
             'create_variables to add this annotation.').format(module.name))
Example #6
0
  optimizer object itself (e.g. through `apply_gradients`).
  """

    # TODO(allenl): Make the restored optimizer functional by tracing its apply
    # methods.

    def __init__(self):
        """Initializes the base class and marks hyperparameters as created."""
        # The literal "RestoredOptimizer" is forwarded to the base initializer
        # (presumably as the optimizer's name -- confirm against the base class).
        super(RestoredOptimizer, self).__init__("RestoredOptimizer")
        # NOTE(review): presumably this stops the base class from creating
        # hyperparameters again after restoration -- confirm.
        self._hypers_created = True

    def get_config(self):
        """Always raises, since this class does not provide a config.

        Raises:
          NotImplementedError: on every call.
        """
        # TODO(allenl): Save and restore the Optimizer's config
        unsupported_message = (
            "Restoring functional Optimzers from SavedModels is not currently "
            "supported. Please file a feature request if this limitation bothers "
            "you.")
        raise NotImplementedError(unsupported_message)


# Register optimizers with `revived_types` under the identifier "optimizer".
# The predicate selects `OptimizerV2` instances, while the version-1 factory
# builds a `RestoredOptimizer` and ignores its `proto` argument.
revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda proto: RestoredOptimizer(),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            # NOTE(review): presumably restored values are attached through
            # `_set_hyper` -- confirm against `revived_types`.
            setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
        )
    ])
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
    if isinstance(params, list):
        params = params[0]
    return embedding_ops.embedding_lookup(params.variables, ids,
                                          partition_strategy, name,
                                          validate_indices, max_norm)


def _raise_when_load(_):
    """Unconditionally raises `ValueError`; the single argument is ignored.

    Raises:
      ValueError: always, with a message pointing at
        `tf.keras.models.load_model` as the alternative.
    """
    # We don't have serialization and deserialization mechanisms for
    # `ShardedVariable` in 2.x style save/load yet.
    unsupported_message = (
        'Loading a saved_model containing ShardedVariable via '
        '`tf.saved_model.load` is not supported. If the model is built using '
        'Keras, please use `tf.keras.models.load_model` instead.')
    raise ValueError(unsupported_message)


# Register `ShardedVariable` with `revived_types` under the identifier
# '_tf_distribute_sharded_variable'. The only registered version uses
# `_raise_when_load` as its factory, which raises `ValueError` whenever it is
# called.
revived_types.register_revived_type(
    '_tf_distribute_sharded_variable',
    lambda obj: isinstance(obj, ShardedVariable),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=_raise_when_load,
            version=0,
            min_producer_version=0,
            min_consumer_version=0)
    ])
            # so we'll throw an exception on save.
            self._non_append_mutation = True
        del self._storage[key]
        self._update_snapshot()

    def __repr__(self):
        """Returns "DictWrapper(...)" wrapping the repr of `self._storage`."""
        storage_repr = repr(self._storage)
        return "DictWrapper(%s)" % (storage_repr,)

    def __hash__(self):
        """Always raises `TypeError`, so instances cannot be hashed."""
        raise TypeError("unhashable type: 'DictWrapper'")

    def __eq__(self, other):
        """Compares `self._storage` with `other._storage`, or with `other` itself.

        When `other` has no `_storage` attribute it is compared directly.
        """
        own_storage = self._storage
        other_storage = getattr(other, "_storage", other)
        return own_storage == other_storage

    def update(self, *args, **kwargs):
        """Assigns every key/value pair of `dict(*args, **kwargs)` onto `self`.

        Each pair is stored through item assignment (`self[key] = value`).
        """
        incoming = dict(*args, **kwargs)
        for key in incoming:
            self[key] = incoming[key]


# Register `_DictWrapper` with `revived_types` under the identifier
# "checkpointable_dict_wrapper". The predicate selects `_DictWrapper`
# instances and the single version-1 registration rebuilds them starting from
# an empty dict.
revived_types.register_revived_type(
    "checkpointable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            # The factory ignores its argument and wraps a fresh empty dict.
            object_factory=lambda _: _DictWrapper({}),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            # `operator.setitem(obj, key, value)` performs `obj[key] = value`.
            setter=operator.setitem)
    ])
Example #9
0
  def __eq__(self, other):
    """Tests equality against `other._storage` if present, else `other`."""
    mine = self._storage
    theirs = getattr(other, "_storage", other)
    return mine == theirs

  def update(self, *args, **kwargs):
    """Copies each entry of `dict(*args, **kwargs)` into `self` by assignment."""
    new_entries = dict(*args, **kwargs)
    for key in new_entries:
      self[key] = new_entries[key]

# Register `_DictWrapper` with `revived_types` under the identifier
# "checkpointable_dict_wrapper". The predicate selects `_DictWrapper`
# instances; the single version-1 registration supplies a factory plus
# setter, getter and attribute-extractor callables.
revived_types.register_revived_type(
    "checkpointable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the checkpointable
        # items in dictionaries, so we don't need to save any extra information.
        # The factory therefore ignores its `proto` argument.
        object_factory=lambda proto: _DictWrapper({}),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        # `operator.setitem(obj, key, value)` performs `obj[key] = value`.
        setter=operator.setitem,
        getter=_DictWrapper.get,
        # Yields the wrapper's keys. NOTE(review): presumably these are the
        # attribute names to restore -- confirm against `revived_types`.
        attribute_extractor=lambda obj: obj.keys())])


def _set_list_item(list_object, index_string, value):
  """Writes `value` at `int(index_string)`, padding the list with `None`.

  Args:
    list_object: Mutable sequence supporting `len`, `extend` and item
      assignment.
    index_string: The target index as a string; converted with `int`.
    value: The object to store at that index.
  """
  target = int(index_string)
  current_length = len(list_object)
  if target >= current_length:
    # Grow the list just enough for `target` to be a valid index.
    padding = [None] * (target - current_length + 1)
    list_object.extend(padding)
  list_object[target] = value


def _list_getter(obj, item, default=None):
  index = int(item)