コード例 #1
0
 def test_register_twice(self):
   """A second registration under an already-used identifier must fail."""
   registration = revived_types.VersionedTypeRegistration(
       object_factory=lambda _: 1,
       version=1,
       min_producer_version=1,
       min_consumer_version=1)
   versions = [registration]
   is_int = lambda x: isinstance(x, int)
   type_name = "foo"
   revived_types.register_revived_type(type_name, is_int, versions)
   # Re-registering the identical (identifier, predicate, versions) triple is
   # expected to trip the registry's duplicate check.
   with self.assertRaisesRegex(AssertionError, "Duplicate registrations"):
     revived_types.register_revived_type(type_name, is_int, versions)
コード例 #2
0
    def _list_functions_for_serialization(self, unused_serialization_cache):
        """Returns a new dict of the entries whose values are functions.

        Only values that are `def_function.Function` or
        `defun.ConcreteFunction` instances are kept. The serialization cache
        argument is accepted for interface compatibility and ignored.
        """
        function_types = (def_function.Function, defun.ConcreteFunction)
        functions = {}
        for name, candidate in self.items():
            if isinstance(candidate, function_types):
                functions[name] = candidate
        return functions


# Register `_SignatureMap` with `revived_types` under the identifier
# "signature_map". The predicate selects `_SignatureMap` instances. The single
# version-1 registration rebuilds an empty map (the `proto` argument is
# ignored) and names `_SignatureMap._add_signature` as the setter.
# NOTE(review): presumably the loader calls the setter once per restored child
# to repopulate the map; confirm against revived_types.
revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[
        revived_types.VersionedTypeRegistration(
            # Standard dependencies are enough to reconstruct the trackable
            # items in this map, so we don't need to save any extra information.
            object_factory=lambda proto: _SignatureMap(),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            setter=_SignatureMap._add_signature  # pylint: disable=protected-access
        )
    ])


def create_signature_map(signatures):
    """Creates an object containing `signatures`."""
    signature_map = _SignatureMap()
    for name, func in signatures.items():
        # This true of any signature that came from canonicalize_signatures. Here as
        # a sanity check on saving; crashing on load (e.g. in _add_signature) would
        # be more problematic in case future export changes violated these
コード例 #3
0
        self.version = version


# Register `CustomTestClass` under "test_type" with four versions, presumably
# so the matching tests can exercise version negotiation. Each factory ignores
# its argument and builds a `CustomTestClass` tagged with its own version:
#   v1: min_producer 1, min_consumer 1
#   v2: min_producer 2, min_consumer 1
#   v3: min_producer 3, min_consumer 2
#   v4: min_producer 4, min_consumer 2, consumer version 3 in `bad_consumers`
revived_types.register_revived_type(
    "test_type",
    lambda obj: isinstance(obj, CustomTestClass),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(1),
            version=1,
            min_producer_version=1,
            min_consumer_version=1),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(2),
            version=2,
            min_producer_version=2,
            min_consumer_version=1),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(3),
            version=3,
            min_producer_version=3,
            min_consumer_version=2),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(4),
            version=4,
            min_producer_version=4,
            min_consumer_version=2,
            bad_consumers=[3]),
    ])


class RegistrationMatchingTest(test.TestCase):
コード例 #4
0
            key: value
            for key, value in self.items() if _is_function(value)
        }


def _is_function(x):
    """Returns whether `x` is a function object that should be serialized.

    Matches instances of `def_function.Function` and
    `defun.ConcreteFunction`.
    """
    function_types = (def_function.Function, defun.ConcreteFunction)
    return isinstance(x, function_types)


# Register `_DictWrapper` under "trackable_dict_wrapper". On load, the factory
# creates an empty `_DictWrapper({})` (the `proto` argument is ignored), and
# `operator.setitem` is the setter.
# NOTE(review): presumably the setter is called as setitem(wrapper, key, value)
# for each restored child; confirm against revived_types.
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            # Standard dependencies are enough to reconstruct the trackable
            # items in dictionaries, so we don't need to save any extra information.
            object_factory=lambda proto: _DictWrapper({}),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            setter=operator.setitem)
    ])


def _set_list_item(list_object, index_string, value):
    """Stores `value` at position `int(index_string)` of `list_object`.

    If the list is too short, it is first padded with `None` so that the
    target index exists. The list is modified in place.
    """
    position = int(index_string)
    shortfall = position + 1 - len(list_object)
    if shortfall > 0:
        list_object.extend([None] * shortfall)
    list_object[position] = value

コード例 #5
0
        return self._variables

    @property
    def trainable_variables(self):
        """Returns the object stored in `self._trainable_variables`."""
        trainable = self._trainable_variables
        return trainable


# Registers the Snapshot object above such that when it is restored by
# tf.saved_model.load it will be restored as a Snapshot. This is important
# because it allows us to expose the __call__, and *_variables properties.
# The factory ignores `proto` and builds `Snapshot()` with no arguments.
# `setattr` is the setter; presumably each restored child is attached as an
# attribute of that name (TODO confirm against revived_types).
revived_types.register_revived_type(
    'acme_snapshot',
    lambda obj: isinstance(obj, Snapshot),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda proto: Snapshot(),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            setter=setattr,
        )
    ])


def make_snapshot(module: snt.Module):
    """Create a thin wrapper around a module to make it snapshottable."""
    # Get the input signature as long as it has been created.
    input_signature = _get_input_signature(module)
    if input_signature is None:
        raise ValueError(
            ('module instance "{}" has no input_signature attribute, '
             'which is required for snapshotting; run '
コード例 #6
0
ファイル: optimizer_v2.py プロジェクト: zk8085454/tensorflow
  optimizer object iself (e.g. through `apply_gradients`).
  """

    # TODO(allenl): Make the restored optimizer functional by tracing its apply
    # methods.

    def __init__(self):
        """Initializes the base optimizer under the name "RestoredOptimizer"."""
        super(RestoredOptimizer, self).__init__("RestoredOptimizer")
        # Set the `_hypers_created` flag only after the base constructor has
        # run. Presumably this marks hyperparameters as already existing and
        # must not be overwritten by the base __init__, so keep this ordering
        # (confirm in OptimizerV2).
        self._hypers_created = True

    def get_config(self):
        """Always raises, because a restored optimizer's config is unavailable.

        Raises:
          NotImplementedError: unconditionally. The optimizer's config is not
            yet saved to or restored from the SavedModel (see TODO below).
        """
        # TODO(allenl): Save and restore the Optimizer's config
        raise NotImplementedError(
            "Restoring functional Optimizers from SavedModels is not currently "
            "supported. Please file a feature request if this limitation bothers "
            "you.")


# Register optimizers under "optimizer". The predicate matches any
# `OptimizerV2` instance. On load, the factory builds a `RestoredOptimizer()`
# (the `proto` argument is ignored), and `RestoredOptimizer._set_hyper` is the
# setter, presumably used to re-attach each restored hyperparameter (confirm
# against revived_types).
revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda proto: RestoredOptimizer(),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
        )
    ])
コード例 #7
0
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
    if isinstance(params, list):
        params = params[0]
    return embedding_ops.embedding_lookup(params.variables, ids,
                                          partition_strategy, name,
                                          validate_indices, max_norm)


def _raise_when_load(_):
    """Object factory that unconditionally rejects reviving a ShardedVariable.

    Args:
      _: Ignored; the factory's single positional argument.

    Raises:
      ValueError: always, directing users to `tf.keras.models.load_model`.
    """
    # 2.x-style save/load has no (de)serialization for `ShardedVariable` yet,
    # so fail with an actionable message instead of reviving a broken object.
    message = ('Loading a saved_model containing ShardedVariable via '
               '`tf.saved_model.load` is not supported. If the model is built '
               'using Keras, please use `tf.keras.models.load_model` instead.')
    raise ValueError(message)


# Register `ShardedVariable` under "_tf_distribute_sharded_variable" with a
# single version-0 registration whose factory is `_raise_when_load`. That
# factory always raises ValueError, so reviving one through this registry
# fails with a pointer to `tf.keras.models.load_model`.
revived_types.register_revived_type(
    '_tf_distribute_sharded_variable',
    lambda obj: isinstance(obj, ShardedVariable),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=_raise_when_load,
            version=0,
            min_producer_version=0,
            min_consumer_version=0)
    ])
コード例 #8
0
    return {
        key: value for key, value in self.items()
        if _is_function(value)
    }


def _is_function(x):
  """Returns whether `x` is a function object that should be serialized.

  Matches instances of `def_function.Function` and `defun.ConcreteFunction`.
  """
  function_classes = (def_function.Function, defun.ConcreteFunction)
  return isinstance(x, function_classes)


# Register `_DictWrapper` under "trackable_dict_wrapper". On load, the factory
# creates an empty `_DictWrapper({})` (the `proto` argument is ignored), and
# `operator.setitem` is the setter.
# NOTE(review): presumably invoked as setitem(wrapper, key, value) per restored
# child; confirm against revived_types.
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the trackable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _DictWrapper({}),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=operator.setitem)])


def _set_list_item(list_object, index_string, value):
  """Stores `value` at position `int(index_string)` of `list_object`.

  The list is padded in place with `None` until the index exists, then the
  slot is assigned.
  """
  position = int(index_string)
  missing = position + 1 - len(list_object)
  if missing > 0:
    list_object.extend([None] * missing)
  list_object[position] = value


revived_types.register_revived_type(
コード例 #9
0
            # so we'll throw an exception on save.
            self._non_append_mutation = True
        del self._storage[key]
        self._update_snapshot()

    def __repr__(self):
        """Returns "DictWrapper(<repr of the wrapped storage>)"."""
        return "DictWrapper({})".format(repr(self._storage))

    def __hash__(self):
        """Always raises TypeError: the wrapper is mutable, so it is unhashable."""
        message = "unhashable type: 'DictWrapper'"
        raise TypeError(message)

    def __eq__(self, other):
        """Compares storage with `other`'s `_storage`, or with `other` itself.

        This lets a wrapper compare equal both to another wrapper and to a
        plain object equal to its storage.
        """
        mine = self._storage
        theirs = getattr(other, "_storage", other)
        return mine == theirs

    def update(self, *args, **kwargs):
        """Dict-style update that routes every pair through `self[key] = value`.

        Accepts whatever the `dict` constructor accepts (a mapping, an iterable
        of pairs, and/or keyword arguments).
        """
        merged = dict(*args, **kwargs)
        for key in merged:
            self[key] = merged[key]


# Register `_DictWrapper` under "checkpointable_dict_wrapper". On load, the
# factory creates an empty `_DictWrapper({})` (its argument is ignored), and
# `operator.setitem` is the setter.
# NOTE(review): presumably called as setitem(wrapper, key, value) per restored
# child; confirm against revived_types.
revived_types.register_revived_type(
    "checkpointable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: _DictWrapper({}),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            setter=operator.setitem)
    ])
コード例 #10
0
  def __repr__(self):
    """Returns "_SignatureMap(<signatures>)" using the string form of the map."""
    return "_SignatureMap(%s)" % (self._signatures,)

  def _list_functions_for_serialization(self):
    """Returns a new dict of this map's entries whose values are functions.

    Only values that are `def_function.Function` or `defun.ConcreteFunction`
    instances are included.
    """
    functions = {}
    for key, value in self.items():
      if isinstance(value, (def_function.Function, defun.ConcreteFunction)):
        functions[key] = value
    return functions


# Register `_SignatureMap` under "signature_map". The predicate selects
# `_SignatureMap` instances. The single version-1 registration rebuilds an
# empty map (the `proto` argument is ignored) with
# `_SignatureMap._add_signature` as the setter.
# NOTE(review): presumably the setter is called once per restored child to
# repopulate the map; confirm against revived_types.
revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the trackable
        # items in this map, so we don't need to save any extra information.
        object_factory=lambda proto: _SignatureMap(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_SignatureMap._add_signature  # pylint: disable=protected-access
    )])


def create_signature_map(signatures):
  """Creates an object containing `signatures`."""
  signature_map = _SignatureMap()
  for name, func in signatures.items():
    # This true of any signature that came from canonicalize_signatures. Here as
    # a sanity check on saving; crashing on load (e.g. in _add_signature) would
    # be more problematic in case future export changes violated these
    # assertions.
コード例 #11
0
class CustomTestClass(tracking.AutoTrackable):
  """Trackable test object that records the version number it was built with.

  The "test_type" registrations construct it as `CustomTestClass(n)` for
  versions 1 through 4, so a test can inspect `version` to see which
  registration revived it.
  """

  def __init__(self, version):
    # Store the version tag as a plain attribute.
    self.version = version


# Register `CustomTestClass` under "test_type" with four versions, presumably
# to exercise version matching in the tests below. Each factory ignores its
# argument and builds a `CustomTestClass` tagged with its own version:
#   v1: min_producer 1, min_consumer 1
#   v2: min_producer 2, min_consumer 1
#   v3: min_producer 3, min_consumer 2
#   v4: min_producer 4, min_consumer 2, consumer version 3 in `bad_consumers`
revived_types.register_revived_type(
    "test_type",
    lambda obj: isinstance(obj, CustomTestClass),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(1),
            version=1, min_producer_version=1,
            min_consumer_version=1),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(2),
            version=2, min_producer_version=2, min_consumer_version=1),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(3),
            version=3, min_producer_version=3, min_consumer_version=2),
        revived_types.VersionedTypeRegistration(
            object_factory=lambda _: CustomTestClass(4),
            version=4, min_producer_version=4, min_consumer_version=2,
            bad_consumers=[3]),
    ]
)


class RegistrationMatchingTest(test.TestCase):
  """Tests for matching objects against `revived_types` registrations."""

  def test_save_typecheck(self):
    # A plain AutoTrackable does not satisfy the "test_type" predicate (which
    # requires CustomTestClass), so serialize() is expected to return None.
    # assertIsNone gives a clearer failure message than assertIs(x, None).
    self.assertIsNone(revived_types.serialize(tracking.AutoTrackable()))
コード例 #12
0
ファイル: data_structures.py プロジェクト: zaazad/tensorflow
    raise TypeError("unhashable type: 'DictWrapper'")

  def __eq__(self, other):
    """Compares storage with `other`'s `_storage`, or with `other` itself."""
    mine = self._storage
    theirs = getattr(other, "_storage", other)
    return mine == theirs

  def update(self, *args, **kwargs):
    """Dict-style update that assigns each pair via `self[key] = value`.

    Accepts whatever the `dict` constructor accepts.
    """
    merged = dict(*args, **kwargs)
    for key in merged:
      self[key] = merged[key]

# Register `_DictWrapper` under "checkpointable_dict_wrapper". On load, the
# factory creates an empty `_DictWrapper({})` (the `proto` argument is
# ignored). `operator.setitem` is the setter, `_DictWrapper.get` the getter,
# and `obj.keys()` the attribute extractor.
# NOTE(review): presumably getter/attribute_extractor enumerate and read
# entries on save, and setter writes them back on load; confirm against
# revived_types.
revived_types.register_revived_type(
    "checkpointable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the checkpointable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _DictWrapper({}),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=operator.setitem,
        getter=_DictWrapper.get,
        attribute_extractor=lambda obj: obj.keys())])


def _set_list_item(list_object, index_string, value):
  """Stores `value` at position `int(index_string)` of `list_object`.

  Pads the list in place with `None` until the index exists, then assigns.
  """
  position = int(index_string)
  missing = position + 1 - len(list_object)
  if missing > 0:
    list_object.extend([None] * missing)
  list_object[position] = value

コード例 #13
0
ファイル: optimizer_v2.py プロジェクト: aritratony/tensorflow
  Holds slot variables and hyperparameters when an optimizer is restored from a
  SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using the
  optimizer object iself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    """Initializes the base optimizer under the name "_RestoredOptimizer"."""
    super(_RestoredOptimizer, self).__init__("_RestoredOptimizer")
    # Set the `_hypers_created` flag only after the base constructor has run.
    # Presumably this marks hyperparameters as already existing and must not be
    # overwritten by the base __init__, so keep this ordering (confirm in
    # OptimizerV2).
    self._hypers_created = True

  def get_config(self):
    """Always raises, because a restored optimizer's config is unavailable.

    Raises:
      NotImplementedError: unconditionally. The optimizer's config is not yet
        saved to or restored from the SavedModel (see TODO below).
    """
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")

# Register optimizers under "optimizer". The predicate matches any
# `OptimizerV2` instance. On load, the factory builds a `_RestoredOptimizer()`
# (the `proto` argument is ignored), with `_RestoredOptimizer._set_hyper` as
# the setter, presumably to re-attach each restored hyperparameter (confirm
# against revived_types).
revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: _RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])