Code Example #1
  def do_decode(self, value, decode_fn):
    """Returns the `tf.TypeSpec` encoded by the proto `value`."""
    type_spec_proto = value.type_spec_value
    type_spec_class_enum = type_spec_proto.type_spec_class
    class_name = type_spec_proto.type_spec_class_name

    if type_spec_class_enum == struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC:
      try:
        type_spec_class = type_spec.lookup(class_name)
      except ValueError as e:
        raise ValueError(
            f"The type '{class_name}' has not been registered.  It must be "
            "registered before you load this object (typically by importing "
            "its module).") from e
    elif type_spec_class_enum == struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC:
      try:
        type_spec_class = type_spec.lookup(class_name)
      except ValueError:
        type_spec_class = extension_type.AnonymousExtensionTypeSpec
        warnings.warn("The type %r has not been registered.  Falling back to "
                      "using AnonymousExtensionTypeSpec instead." % class_name)
    else:
      if type_spec_class_enum not in self.TYPE_SPEC_CLASS_FROM_PROTO:
        raise ValueError(
            f"The type '{class_name}' is not supported by this version of "
            "TensorFlow. (The object you are loading must have been created "
            "with a newer version of TensorFlow.)")
      type_spec_class = self.TYPE_SPEC_CLASS_FROM_PROTO[type_spec_class_enum]

    # pylint: disable=protected-access
    return type_spec_class._deserialize(decode_fn(type_spec_proto.type_state))
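
The decoder above only works if the spec class was previously registered under the name stored in the proto. As orientation, here is a minimal sketch (not part of the original example) of the register/lookup round-trip it relies on; the class and registry name `my_pkg.SimpleSpec` are hypothetical, and the private `tensorflow.python.framework.type_spec` import path can differ between TensorFlow versions.

import tensorflow as tf
from tensorflow.python.framework import type_spec


@type_spec.register("my_pkg.SimpleSpec")  # hypothetical registry name
class SimpleSpec(type_spec.TypeSpec):
  """A trivial TypeSpec used only to illustrate the registry round-trip."""

  @property
  def value_type(self):
    return tf.Tensor  # illustration only

  def _serialize(self):
    return ()  # no state to encode

  @property
  def _component_specs(self):
    return ()

  def _to_components(self, value):
    return ()

  def _from_components(self, components):
    return None


# `do_decode` resolves the name stored in the proto back to the class...
assert type_spec.lookup("my_pkg.SimpleSpec") is SimpleSpec
assert type_spec.get_name(SimpleSpec) == "my_pkg.SimpleSpec"
# ...and then calls the class's `_deserialize` on the decoded state (the
# default implementation simply calls `cls(*state)`).
assert isinstance(SimpleSpec._deserialize(()), SimpleSpec)  # pylint: disable=protected-access

Code Example #3 below exercises the same `get_name`/`lookup` pair on actually registered spec classes.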
Code Example #2
 def testRegistryLookupErrors(self):
     with self.assertRaises(TypeError):
         type_spec.lookup(None)
     with self.assertRaisesRegex(
             ValueError,
             "No TypeSpec has been registered with name 'foo.bar'"):
         type_spec.lookup("foo.bar")
Code Example #3
 def testRegistry(self):
   self.assertEqual("tf.TwoCompositesSpec",
                    type_spec.get_name(TwoCompositesSpec))
   self.assertEqual("tf.TwoTensorsSpec", type_spec.get_name(TwoTensorsSpec))
   self.assertEqual(TwoCompositesSpec,
                    type_spec.lookup("tf.TwoCompositesSpec"))
   self.assertEqual(TwoTensorsSpec, type_spec.lookup("tf.TwoTensorsSpec"))
Code Example #4
    def testLoadSavedModelWithUnregisteredStruct(self):
        MaskedTensor = build_simple_masked_tensor_type()

        def f(x, y):
            x_values = x.values if isinstance(x, MaskedTensor) else x
            y_values = y.values if isinstance(y, MaskedTensor) else y
            x_mask = x.mask if isinstance(x, MaskedTensor) else True
            y_mask = y.mask if isinstance(y, MaskedTensor) else True
            return MaskedTensor(x_values + y_values, x_mask & y_mask)

        t_spec = tensor_spec.TensorSpec(None, dtypes.int32)
        b_spec = tensor_spec.TensorSpec(None, dtypes.bool)
        mt_spec = MaskedTensor.Spec(values=t_spec, mask=b_spec)
        model = module.Module()
        model.f = def_function.function(f)
        model.f.get_concrete_function(t_spec, t_spec)
        model.f.get_concrete_function(t_spec, mt_spec)
        model.f.get_concrete_function(mt_spec, t_spec)
        model.f.get_concrete_function(mt_spec, mt_spec)

        path = tempfile.mkdtemp(prefix=test.get_temp_dir())
        with temporarily_register_type_spec('tf.test.MaskedTensor.Spec',
                                            MaskedTensor.Spec):
            save.save(model, path)
        loaded_model = load.load(path)

        with self.assertRaises(ValueError):
            type_spec.lookup('tf.test.MaskedTensor')

        t = constant_op.constant([10, 20, 30])
        v1 = loaded_model.f(t, t)
        self.assertIsInstance(v1, tensor_struct.AnonymousStruct)
        self.assertAllEqual(v1.values, [20, 40, 60])
        self.assertAllEqual(v1.mask, True)

        v2 = loaded_model.f(v1, v1)
        self.assertIsInstance(v2, tensor_struct.AnonymousStruct)
        self.assertAllEqual(v2.values, [40, 80, 120])
        self.assertAllEqual(v2.mask, True)

        mt = MaskedTensor([1, 2, 3], [True, True, False])
        v3 = loaded_model.f(
            t,
            tensor_struct.reinterpret_struct(mt,
                                             tensor_struct.AnonymousStruct))
        self.assertIsInstance(v3, tensor_struct.AnonymousStruct)
        self.assertAllEqual(v3.values, [11, 22, 33])
        self.assertAllEqual(v3.mask, [True, True, False])

        v4 = tensor_struct.reinterpret_struct(v3, MaskedTensor)
        self.assertIsInstance(v4, MaskedTensor)
        self.assertAllEqual(v4.values, [11, 22, 33])
        self.assertAllEqual(v4.mask, [True, True, False])
Code Example #5
    def do_decode(self, value, decode_fn):
        """Returns the `tf.TypeSpec` encoded by the proto `value`."""
        type_spec_proto = value.type_spec_value
        type_spec_class_enum = type_spec_proto.type_spec_class
        class_name = type_spec_proto.type_spec_class_name

        if type_spec_class_enum == struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC:
            try:
                type_spec_class = type_spec.lookup(class_name)
            except ValueError as e:
                raise ValueError(
                    "The type '%s' has not been registered.  It must be registered "
                    "before you load this object (typically by importing its module)."
                    % class_name) from e
        else:
            if type_spec_class_enum not in self.TYPE_SPEC_CLASS_FROM_PROTO:
                raise ValueError(
                    "The type '%s' is not supported by this version of TensorFlow. "
                    "(The object you are loading must have been created with a newer "
                    "version of TensorFlow.)" % class_name)
            type_spec_class = self.TYPE_SPEC_CLASS_FROM_PROTO[
                type_spec_class_enum]

        # pylint: disable=protected-access
        return type_spec_class._deserialize(
            decode_fn(type_spec_proto.type_state))
Code Example #6
def _decode_helper(
    obj, deserialize=False, module_objects=None, custom_objects=None
):
    """A decoding helper that is TF-object aware.

    Args:
      obj: A decoded dictionary that may represent an object.
      deserialize: Boolean, defaults to False. When True, deserializes any Keras
        objects found in `obj`.
      module_objects: A dictionary of built-in objects to look the name up in.
        Generally, `module_objects` is provided by midlevel library implementers.
      custom_objects: A dictionary of custom objects to look the name up in.
        Generally, `custom_objects` is provided by the end user.

    Returns:
      The decoded object.
    """
    if isinstance(obj, dict) and "class_name" in obj:
        if obj["class_name"] == "TensorShape":
            return tf.TensorShape(obj["items"])
        elif obj["class_name"] == "TypeSpec":
            return type_spec.lookup(
                obj["type_spec"]
            )._deserialize(  # pylint: disable=protected-access
                _decode_helper(obj["serialized"])
            )
        elif obj["class_name"] == "CompositeTensor":
            spec = obj["spec"]
            tensors = []
            for dtype, tensor in obj["tensors"]:
                tensors.append(
                    tf.constant(tensor, dtype=tf.dtypes.as_dtype(dtype))
                )
            return tf.nest.pack_sequence_as(
                _decode_helper(spec), tensors, expand_composites=True
            )
        elif obj["class_name"] == "__tuple__":
            return tuple(_decode_helper(i) for i in obj["items"])
        elif obj["class_name"] == "__ellipsis__":
            return Ellipsis
        elif deserialize and "__passive_serialization__" in obj:
            # __passive_serialization__ is added by the JSON encoder when encoding
            # an object that has a `get_config()` method.
            try:
                return generic_utils.deserialize_keras_object(
                    obj,
                    module_objects=module_objects,
                    custom_objects=custom_objects,
                )
            except ValueError:
                pass
    return obj
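
As a rough illustration (the dict literals below are made up, not taken from a real saved model), here is how the helper above handles the `class_name`-tagged entries it recognizes:

import tensorflow as tf  # the helper above assumes this import

# Plain TensorShape entry.
shape = _decode_helper({"class_name": "TensorShape", "items": [None, 3]})
# -> TensorShape([None, 3])

# Tuples are round-tripped via the "__tuple__" tag; nested entries are
# decoded recursively.
pair = _decode_helper(
    {"class_name": "__tuple__",
     "items": [{"class_name": "TensorShape", "items": [2]}, 5]})
# -> (TensorShape([2]), 5)

# Ellipsis is encoded with its own tag.
assert _decode_helper({"class_name": "__ellipsis__"}) is Ellipsis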
Code Example #7
def _decode_helper(obj):
    """A decoding helper that is TF-object aware."""
    if isinstance(obj, dict) and 'class_name' in obj:
        if obj['class_name'] == 'TensorShape':
            return tf.TensorShape(obj['items'])
        elif obj['class_name'] == 'TypeSpec':
            return type_spec.lookup(obj['type_spec'])._deserialize(  # pylint: disable=protected-access
                _decode_helper(obj['serialized']))
        elif obj['class_name'] == '__tuple__':
            return tuple(_decode_helper(i) for i in obj['items'])
        elif obj['class_name'] == '__ellipsis__':
            return Ellipsis
    return obj
Code Example #8
File: json_utils.py  Project: ttigong/keras
def _decode_helper(obj,
                   deserialize=False,
                   module_objects=None,
                   custom_objects=None):
    """A decoding helper that is TF-object aware.

  Args:
    obj: A decoded dictionary that may represent an object.
    deserialize: Boolean, defaults to False. When True, deserializes any Keras
      objects found in `obj`.
    module_objects: A dictionary of built-in objects to look the name up in.
      Generally, `module_objects` is provided by midlevel library implementers.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, `custom_objects` is provided by the end user.

  Returns:
    The decoded object.
  """
    if isinstance(obj, dict) and 'class_name' in obj:
        if obj['class_name'] == 'TensorShape':
            return tf.TensorShape(obj['items'])
        elif obj['class_name'] == 'TypeSpec':
            return type_spec.lookup(obj['type_spec'])._deserialize(  # pylint: disable=protected-access
                _decode_helper(obj['serialized']))
        elif obj['class_name'] == '__tuple__':
            return tuple(_decode_helper(i) for i in obj['items'])
        elif obj['class_name'] == '__ellipsis__':
            return Ellipsis
        elif deserialize and '__passive_serialization__' in obj:
            # __passive_serialization__ is added by the JSON encoder when encoding
            # an object that has a `get_config()` method.
            try:
                return generic_utils.deserialize_keras_object(
                    obj,
                    module_objects=module_objects,
                    custom_objects=custom_objects)
            except ValueError:
                pass
    return obj
Code Example #9
def auto_composite_tensor(cls=None, omit_kwargs=(), module_name=None):
  """Automagically generate `CompositeTensor` behavior for `cls`.

  `CompositeTensor` objects are able to pass in and out of `tf.function` and
  `tf.while_loop`, or serve as part of the signature of a TF saved model.

  The contract of `auto_composite_tensor` is that all __init__ args and kwargs
  must have corresponding public or private attributes (or properties). Each of
  these attributes is inspected (recursively) to determine whether it is (or
  contains) `Tensor`s or non-`Tensor` metadata. `list` and `tuple` attributes
  are supported, but must either contain *only* `Tensor`s (or lists, etc,
  thereof), or *no* `Tensor`s. E.g.,
    - object.attribute = [1., 2., 'abc']                        # valid
    - object.attribute = [tf.constant(1.), [tf.constant(2.)]]   # valid
    - object.attribute = ['abc', tf.constant(1.)]               # invalid

  If the attribute is a callable, serialization of the `TypeSpec`, and therefore
  interoperability with `tf.saved_model`, is not currently supported. As a
  workaround, callables that do not contain or close over `Tensor`s may be
  expressed as functors that subclass `AutoCompositeTensor` and used in place of
  the original callable arg:

  ```python
  @auto_composite_tensor(module_name='my.module')
  class F(AutoCompositeTensor):

    def __call__(self, *args, **kwargs):
      return original_callable(*args, **kwargs)
  ```

  Callable objects that do contain or close over `Tensor`s should either
  (1) subclass `AutoCompositeTensor`, with the `Tensor`s passed to the
  constructor, (2) subclass `CompositeTensor` and implement their own
  `TypeSpec`, or (3) have a conversion function registered with
  `type_spec.register_type_spec_from_value_converter`.

  If the object has a `_composite_tensor_shape_parameters` field (presumed to
  have `tuple` of `str` value), the flattening code will use
  `tf.get_static_value` to attempt to preserve shapes as static metadata, for
  fields whose name matches a name specified in that field. Preserving static
  values can be important to correctly propagating shapes through a loop.
  Note that the Distribution and Bijector base classes provide a
  default implementation of `_composite_tensor_shape_parameters`, populated by
  `parameter_properties` annotations.

  If the decorated class `A` does not subclass `CompositeTensor`, a *new class*
  will be generated, which mixes in `A` and `CompositeTensor`.

  To avoid this extra class in the class hierarchy, we suggest inheriting from
  `auto_composite_tensor.AutoCompositeTensor`, which inherits from
  `CompositeTensor` and implants a trivial `_type_spec` @property. The
  `@auto_composite_tensor` decorator will then overwrite this trivial
  `_type_spec` @property. The trivial one is necessary because `_type_spec` is
  an abstract property of `CompositeTensor`, and a valid class instance must be
  created before the decorator can execute -- without the trivial `_type_spec`
  property present, `ABCMeta` will throw an error! The user may thus do any of
  the following:

  #### `AutoCompositeTensor` base class (recommended)
  ```python
  @tfp.experimental.auto_composite_tensor
  class MyClass(tfp.experimental.AutoCompositeTensor):
    ...

  mc = MyClass()
  type(mc)
  # ==> MyClass
  ```

  #### No `CompositeTensor` base class (ok, but changes expected types)
  ```python
  @tfp.experimental.auto_composite_tensor
  class MyClass(object):
    ...

  mc = MyClass()
  type(mc)
  # ==> MyClass_AutoCompositeTensor
  ```

  #### `CompositeTensor` base class, requiring trivial `_type_spec`
  ```python
  from tensorflow.python.framework import composite_tensor
  @tfp.experimental.auto_composite_tensor
  class MyClass(composite_tensor.CompositeTensor):
    @property
    def _type_spec(self):  # will be overwritten by @auto_composite_tensor
      pass
    ...

  mc = MyClass()
  type(mc)
  # ==> MyClass
  ```

  ## Full usage example

  ```python
  @tfp.experimental.auto_composite_tensor(omit_kwargs=('name',))
  class Adder(tfp.experimental.AutoCompositeTensor):
    def __init__(self, x, y, name=None):
      with tf.name_scope(name or 'Adder') as name:
        self._x = tf.convert_to_tensor(x)
        self._y = tf.convert_to_tensor(y)
        self._name = name

    def xpy(self):
      return self._x + self._y

  def body(obj):
    return Adder(obj.xpy(), 1.),

  result, = tf.while_loop(
      cond=lambda _: True,
      body=body,
      loop_vars=(Adder(1., 1.),),
      maximum_iterations=3)

  result.xpy()  # => 5.
  ```

  Args:
    cls: The class for which to create a CompositeTensor subclass.
    omit_kwargs: Optional sequence of kwarg names to be omitted from the spec.
    module_name: The module name with which to register the `TypeSpec`. If
      `None`, defaults to `cls.__module__`.

  Returns:
    composite_tensor_subclass: A subclass of `cls` and TF CompositeTensor.
  """
  if cls is None:
    return functools.partial(auto_composite_tensor,
                             omit_kwargs=omit_kwargs,
                             module_name=module_name)

  if module_name is None:
    module_name = cls.__module__

  type_spec_class_name = f'{cls.__name__}_ACTTypeSpec'
  type_spec_name = f'{module_name}.{type_spec_class_name}'

  try:
    ts = type_spec.lookup(type_spec_name)
    return ts.value_type.fget(None)
  except ValueError:
    pass

  # If the declared class is already a CompositeTensor subclass, we can avoid
  # affecting the actual type of the returned class. Otherwise, we need to
  # explicitly mix in the CT type, and hence create and return a newly
  # synthesized type.
  if issubclass(cls, composite_tensor.CompositeTensor):

    @type_spec.register(type_spec_name)
    class _AlreadyCTTypeSpec(_AutoCompositeTensorTypeSpec):

      @property
      def value_type(self):
        return cls

    _AlreadyCTTypeSpec.__name__ = type_spec_class_name

    cls._type_spec = property(  # pylint: disable=protected-access
        lambda self: _AlreadyCTTypeSpec.from_instance(self, omit_kwargs))
    return cls

  clsid = (cls.__module__, cls.__name__, omit_kwargs)

  # Check for subclass if retrieving from the _registry, in case the user
  # has redefined the class (e.g. in a REPL/notebook).
  if clsid in _registry and issubclass(_registry[clsid], cls):
    return _registry[clsid]

  @type_spec.register(type_spec_name)
  class _GeneratedCTTypeSpec(_AutoCompositeTensorTypeSpec):

    @property
    def value_type(self):
      return _registry[clsid]

  _GeneratedCTTypeSpec.__name__ = type_spec_class_name

  class _AutoCompositeTensor(cls, composite_tensor.CompositeTensor):
    """A per-`cls` subclass of `CompositeTensor`."""

    @property
    def _type_spec(self):
      return _GeneratedCTTypeSpec.from_instance(self, omit_kwargs)

  _AutoCompositeTensor.__name__ = cls.__name__
  _registry[clsid] = _AutoCompositeTensor
  return _AutoCompositeTensor