def do_decode(self, value, decode_fn):
  """Returns the `tf.TypeSpec` encoded by the proto `value`.

  The TypeSpec class is resolved in one of three ways, depending on the
  proto's `type_spec_class` enum:

  * `REGISTERED_TYPE_SPEC`: looked up by name in the TypeSpec registry; an
    unknown name is an error.
  * `EXTENSION_TYPE_SPEC`: looked up by name in the TypeSpec registry; an
    unknown name falls back to `extension_type.AnonymousExtensionTypeSpec`
    and emits a warning.
  * anything else: looked up in `self.TYPE_SPEC_CLASS_FROM_PROTO`.

  Args:
    value: A proto whose `type_spec_value` field holds the encoded spec.
    decode_fn: Callable used to decode the spec's nested `type_state`.

  Returns:
    The result of the resolved class's `_deserialize` applied to the decoded
    `type_state`.

  Raises:
    ValueError: If a `REGISTERED_TYPE_SPEC` names a type that has not been
      registered, or if the class enum is not present in
      `self.TYPE_SPEC_CLASS_FROM_PROTO`.
  """
  type_spec_proto = value.type_spec_value
  type_spec_class_enum = type_spec_proto.type_spec_class
  class_name = type_spec_proto.type_spec_class_name

  if type_spec_class_enum == struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC:
    try:
      type_spec_class = type_spec.lookup(class_name)
    except ValueError as e:
      raise ValueError(
          f"The type '{class_name}' has not been registered. It must be "
          "registered before you load this object (typically by importing "
          "its module).") from e
  elif type_spec_class_enum == struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC:
    try:
      type_spec_class = type_spec.lookup(class_name)
    except ValueError:
      # Unregistered extension types are still loadable, just generically.
      type_spec_class = extension_type.AnonymousExtensionTypeSpec
      # Interpolate the type name so the warning says which type is missing
      # (a bare "%r" placeholder here would be printed literally).
      warnings.warn(f"The type {class_name!r} has not been registered. "
                    "Falling back to using AnonymousExtensionTypeSpec "
                    "instead.")
  else:
    if type_spec_class_enum not in self.TYPE_SPEC_CLASS_FROM_PROTO:
      raise ValueError(
          f"The type '{class_name}' is not supported by this version of "
          "TensorFlow. (The object you are loading must have been created "
          "with a newer version of TensorFlow.)")
    type_spec_class = self.TYPE_SPEC_CLASS_FROM_PROTO[type_spec_class_enum]

  # pylint: disable=protected-access
  return type_spec_class._deserialize(decode_fn(type_spec_proto.type_state))
def testRegistryLookupErrors(self):
  """`type_spec.lookup` rejects bad names with the appropriate errors."""
  # Passing `None` as the name must raise a TypeError.
  with self.assertRaises(TypeError):
    type_spec.lookup(None)

  # An unregistered name must raise a ValueError that mentions the name.
  missing_name = "foo.bar"
  expected_message = "No TypeSpec has been registered with name 'foo.bar'"
  with self.assertRaisesRegex(ValueError, expected_message):
    type_spec.lookup(missing_name)
def testRegistry(self):
  """Registered spec classes round-trip through `get_name` and `lookup`."""
  registered = (
      ("tf.TwoCompositesSpec", TwoCompositesSpec),
      ("tf.TwoTensorsSpec", TwoTensorsSpec),
  )
  # Class -> registered name.
  for expected_name, spec_cls in registered:
    self.assertEqual(expected_name, type_spec.get_name(spec_cls))
  # Registered name -> class.
  for expected_name, spec_cls in registered:
    self.assertEqual(spec_cls, type_spec.lookup(expected_name))
def testLoadSavedModelWithUnregisteredStruct(self):
  """Loads a SavedModel whose struct spec is not registered at load time.

  The model is saved while `MaskedTensor.Spec` is registered under
  'tf.test.MaskedTensor.Spec' (inside `temporarily_register_type_spec`).
  The model is then loaded after that block has exited. The loaded function
  must still run on plain tensors and `AnonymousStruct` inputs, returning
  `tensor_struct.AnonymousStruct` values. Those values can be reinterpreted
  as `MaskedTensor`.
  """
  MaskedTensor = build_simple_masked_tensor_type()

  def f(x, y):
    # Plain tensors are treated as values with an all-True mask.
    x_values = x.values if isinstance(x, MaskedTensor) else x
    y_values = y.values if isinstance(y, MaskedTensor) else y
    x_mask = x.mask if isinstance(x, MaskedTensor) else True
    y_mask = y.mask if isinstance(y, MaskedTensor) else True
    return MaskedTensor(x_values + y_values, x_mask & y_mask)

  t_spec = tensor_spec.TensorSpec(None, dtypes.int32)
  b_spec = tensor_spec.TensorSpec(None, dtypes.bool)
  mt_spec = MaskedTensor.Spec(values=t_spec, mask=b_spec)

  model = module.Module()
  model.f = def_function.function(f)
  # Trace all four tensor / MaskedTensor argument combinations so that each
  # signature is saved with the model.
  model.f.get_concrete_function(t_spec, t_spec)
  model.f.get_concrete_function(t_spec, mt_spec)
  model.f.get_concrete_function(mt_spec, t_spec)
  model.f.get_concrete_function(mt_spec, mt_spec)

  path = tempfile.mkdtemp(prefix=test.get_temp_dir())
  with temporarily_register_type_spec('tf.test.MaskedTensor.Spec',
                                      MaskedTensor.Spec):
    save.save(model, path)
  loaded_model = load.load(path)

  # The spec must no longer be registered. Look up the exact name registered
  # above: 'tf.test.MaskedTensor' was never registered, so looking that up
  # would raise regardless and the check could never fail.
  with self.assertRaises(ValueError):
    type_spec.lookup('tf.test.MaskedTensor.Spec')

  t = constant_op.constant([10, 20, 30])
  v1 = loaded_model.f(t, t)
  self.assertIsInstance(v1, tensor_struct.AnonymousStruct)
  self.assertAllEqual(v1.values, [20, 40, 60])
  self.assertAllEqual(v1.mask, True)

  v2 = loaded_model.f(v1, v1)
  self.assertIsInstance(v2, tensor_struct.AnonymousStruct)
  self.assertAllEqual(v2.values, [40, 80, 120])
  self.assertAllEqual(v2.mask, True)

  mt = MaskedTensor([1, 2, 3], [True, True, False])
  v3 = loaded_model.f(
      t, tensor_struct.reinterpret_struct(mt, tensor_struct.AnonymousStruct))
  self.assertIsInstance(v3, tensor_struct.AnonymousStruct)
  self.assertAllEqual(v3.values, [11, 22, 33])
  self.assertAllEqual(v3.mask, [True, True, False])

  # An anonymous result can be turned back into the original struct type.
  v4 = tensor_struct.reinterpret_struct(v3, MaskedTensor)
  self.assertIsInstance(v4, MaskedTensor)
  self.assertAllEqual(v4.values, [11, 22, 33])
  self.assertAllEqual(v4.mask, [True, True, False])
def do_decode(self, value, decode_fn):
  """Rebuilds the `tf.TypeSpec` stored in the proto `value`.

  For `REGISTERED_TYPE_SPEC` protos, the spec class is resolved by name
  through the TypeSpec registry. Otherwise it is resolved from
  `self.TYPE_SPEC_CLASS_FROM_PROTO`, keyed by the proto's class enum. The
  nested `type_state` is decoded with `decode_fn` and passed to that
  class's `_deserialize`.

  Args:
    value: A proto whose `type_spec_value` field holds the encoded spec.
    decode_fn: Callable used to decode the spec's nested `type_state`.

  Returns:
    The object returned by the resolved class's `_deserialize`.

  Raises:
    ValueError: If a registered-by-name spec is missing from the registry, or
      if the class enum is not a key of `self.TYPE_SPEC_CLASS_FROM_PROTO`.
  """
  encoded = value.type_spec_value
  class_enum = encoded.type_spec_class
  class_name = encoded.type_spec_class_name

  if class_enum != struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC:
    known_classes = self.TYPE_SPEC_CLASS_FROM_PROTO
    if class_enum not in known_classes:
      raise ValueError(
          "The type '%s' is not supported by this version of TensorFlow. "
          "(The object you are loading must have been created with a newer "
          "version of TensorFlow.)" % class_name)
    spec_class = known_classes[class_enum]
  else:
    try:
      spec_class = type_spec.lookup(class_name)
    except ValueError as lookup_error:
      raise ValueError(
          "The type '%s' has not been registered. It must be registered "
          "before you load this object (typically by importing its module)."
          % class_name) from lookup_error

  # pylint: disable=protected-access
  return spec_class._deserialize(decode_fn(encoded.type_state))
def _decode_helper( obj, deserialize=False, module_objects=None, custom_objects=None ): """A decoding helper that is TF-object aware. Args: obj: A decoded dictionary that may represent an object. deserialize: Boolean, defaults to False. When True, deserializes any Keras objects found in `obj`. module_objects: A dictionary of built-in objects to look the name up in. Generally, `module_objects` is provided by midlevel library implementers. custom_objects: A dictionary of custom objects to look the name up in. Generally, `custom_objects` is provided by the end user. Returns: The decoded object. """ if isinstance(obj, dict) and "class_name" in obj: if obj["class_name"] == "TensorShape": return tf.TensorShape(obj["items"]) elif obj["class_name"] == "TypeSpec": return type_spec.lookup( obj["type_spec"] )._deserialize( # pylint: disable=protected-access _decode_helper(obj["serialized"]) ) elif obj["class_name"] == "CompositeTensor": spec = obj["spec"] tensors = [] for dtype, tensor in obj["tensors"]: tensors.append( tf.constant(tensor, dtype=tf.dtypes.as_dtype(dtype)) ) return tf.nest.pack_sequence_as( _decode_helper(spec), tensors, expand_composites=True ) elif obj["class_name"] == "__tuple__": return tuple(_decode_helper(i) for i in obj["items"]) elif obj["class_name"] == "__ellipsis__": return Ellipsis elif deserialize and "__passive_serialization__" in obj: # __passive_serialization__ is added by the JSON encoder when encoding # an object that has a `get_config()` method. try: return generic_utils.deserialize_keras_object( obj, module_objects=module_objects, custom_objects=custom_objects, ) except ValueError: pass return obj
def _decode_helper(obj): """A decoding helper that is TF-object aware.""" if isinstance(obj, dict) and 'class_name' in obj: if obj['class_name'] == 'TensorShape': return tf.TensorShape(obj['items']) elif obj['class_name'] == 'TypeSpec': return type_spec.lookup(obj['type_spec'])._deserialize( # pylint: disable=protected-access _decode_helper(obj['serialized'])) elif obj['class_name'] == '__tuple__': return tuple(_decode_helper(i) for i in obj['items']) elif obj['class_name'] == '__ellipsis__': return Ellipsis return obj
def _decode_helper(obj, deserialize=False, module_objects=None, custom_objects=None): """A decoding helper that is TF-object aware. Args: obj: A decoded dictionary that may represent an object. deserialize: Boolean, defaults to False. When True, deserializes any Keras objects found in `obj`. module_objects: A dictionary of built-in objects to look the name up in. Generally, `module_objects` is provided by midlevel library implementers. custom_objects: A dictionary of custom objects to look the name up in. Generally, `custom_objects` is provided by the end user. Returns: The decoded object. """ if isinstance(obj, dict) and 'class_name' in obj: if obj['class_name'] == 'TensorShape': return tf.TensorShape(obj['items']) elif obj['class_name'] == 'TypeSpec': return type_spec.lookup(obj['type_spec'])._deserialize( # pylint: disable=protected-access _decode_helper(obj['serialized'])) elif obj['class_name'] == '__tuple__': return tuple(_decode_helper(i) for i in obj['items']) elif obj['class_name'] == '__ellipsis__': return Ellipsis elif deserialize and '__passive_serialization__' in obj: # __passive_serialization__ is added by the JSON encoder when encoding # an object that has a `get_config()` method. try: return generic_utils.deserialize_keras_object( obj, module_objects=module_objects, custom_objects=custom_objects) except ValueError: pass return obj
def auto_composite_tensor(cls=None, omit_kwargs=(), module_name=None):
  """Automagically generate `CompositeTensor` behavior for `cls`.

  `CompositeTensor` objects are able to pass in and out of `tf.function` and
  `tf.while_loop`, or serve as part of the signature of a TF saved model.

  The contract of `auto_composite_tensor` is that all __init__ args and kwargs
  must have corresponding public or private attributes (or properties). Each of
  these attributes is inspected (recursively) to determine whether it is (or
  contains) `Tensor`s or non-`Tensor` metadata. `list` and `tuple` attributes
  are supported, but must either contain *only* `Tensor`s (or lists, etc,
  thereof), or *no* `Tensor`s. E.g.,
    - object.attribute = [1., 2., 'abc']                        # valid
    - object.attribute = [tf.constant(1.), [tf.constant(2.)]]   # valid
    - object.attribute = ['abc', tf.constant(1.)]               # invalid

  If the attribute is a callable, serialization of the `TypeSpec`, and therefore
  interoperability with `tf.saved_model`, is not currently supported. As a
  workaround, callables that do not contain or close over `Tensor`s may be
  expressed as functors that subclass `AutoCompositeTensor` and used in place of
  the original callable arg:

  ```python
  @auto_composite_tensor(module_name='my.module')
  class F(AutoCompositeTensor):

    def __call__(self, *args, **kwargs):
      return original_callable(*args, **kwargs)
  ```

  Callable objects that do contain or close over `Tensor`s should either
  (1) subclass `AutoCompositeTensor`, with the `Tensor`s passed to the
  constructor, (2) subclass `CompositeTensor` and implement their own
  `TypeSpec`, or (3) have a conversion function registered with
  `type_spec.register_type_spec_from_value_converter`.

  If the object has a `_composite_tensor_shape_parameters` field (presumed to
  have `tuple` of `str` value), the flattening code will use
  `tf.get_static_value` to attempt to preserve shapes as static metadata, for
  fields whose name matches a name specified in that field. Preserving static
  values can be important to correctly propagating shapes through a loop.
  Note that the Distribution and Bijector base classes provide a
  default implementation of `_composite_tensor_shape_parameters`, populated by
  `parameter_properties` annotations.

  If the decorated class `A` does not subclass `CompositeTensor`, a *new class*
  will be generated, which mixes in `A` and `CompositeTensor`.

  To avoid this extra class in the class hierarchy, we suggest inheriting from
  `auto_composite_tensor.AutoCompositeTensor`, which inherits from
  `CompositeTensor` and implants a trivial `_type_spec` @property. The
  `@auto_composite_tensor` decorator will then overwrite this trivial
  `_type_spec` @property. The trivial one is necessary because `_type_spec` is
  an abstract property of `CompositeTensor`, and a valid class instance must be
  created before the decorator can execute -- without the trivial `_type_spec`
  property present, `ABCMeta` will throw an error! The user may thus do any of
  the following:

  #### `AutoCompositeTensor` base class (recommended)

  ```python
  @tfp.experimental.auto_composite_tensor
  class MyClass(tfp.experimental.AutoCompositeTensor):
    ...

  mc = MyClass()
  type(mc)
  # ==> MyClass
  ```

  #### No `CompositeTensor` base class (ok, but changes expected types)

  ```python
  @tfp.experimental.auto_composite_tensor
  class MyClass(object):
    ...

  mc = MyClass()
  type(mc)
  # ==> a generated subclass of MyClass and CompositeTensor (its `__name__`
  #     is also 'MyClass', but it is a different class object)
  ```

  #### `CompositeTensor` base class, requiring trivial `_type_spec`

  ```python
  from tensorflow.python.framework import composite_tensor
  @tfp.experimental.auto_composite_tensor
  class MyClass(composite_tensor.CompositeTensor):
    @property
    def _type_spec(self):  # will be overwritten by @auto_composite_tensor
      pass
    ...

  mc = MyClass()
  type(mc)
  # ==> MyClass
  ```

  ## Full usage example

  ```python
  @tfp.experimental.auto_composite_tensor(omit_kwargs=('name',))
  class Adder(tfp.experimental.AutoCompositeTensor):

    def __init__(self, x, y, name=None):
      with tf.name_scope(name or 'Adder') as name:
        self._x = tf.convert_to_tensor(x)
        self._y = tf.convert_to_tensor(y)
        self._name = name

    def xpy(self):
      return self._x + self._y

  def body(obj):
    return Adder(obj.xpy(), 1.),

  result, = tf.while_loop(
      cond=lambda _: True,
      body=body,
      loop_vars=(Adder(1., 1.),),
      maximum_iterations=3)

  result.xpy()  # => 5.
  ```

  Args:
    cls: The class for which to create a CompositeTensor subclass. If `None`,
      a `functools.partial` of this function (bound to `omit_kwargs` and
      `module_name`) is returned, so the function can be used as a decorator
      with arguments.
    omit_kwargs: Optional sequence of kwarg names to be omitted from the spec.
      `None` is treated as `()`. Non-`tuple` sequences such as a `list` are
      converted to a `tuple`, because the value becomes part of a dict key.
      A `str` is passed through unchanged.
    module_name: The module name with which to register the `TypeSpec`.
      If `None`, defaults to `cls.__module__`.

  Returns:
    composite_tensor_subclass: A subclass of `cls` and TF CompositeTensor.
      Special cases:
      - If `cls` already subclasses `CompositeTensor`, `cls` itself is
        returned with its `_type_spec` property replaced.
      - If a `TypeSpec` is already registered under
        `'{module_name}.{cls.__name__}_ACTTypeSpec'`, that spec's
        `value_type` is returned.
  """
  # `omit_kwargs` becomes part of `clsid`, which is used as a dict key below,
  # so it must be hashable. Normalize lists and other sequences to a tuple.
  if omit_kwargs is None:
    omit_kwargs = ()
  elif not isinstance(omit_kwargs, (tuple, str)):
    omit_kwargs = tuple(omit_kwargs)

  if cls is None:
    return functools.partial(auto_composite_tensor,
                             omit_kwargs=omit_kwargs,
                             module_name=module_name)

  if module_name is None:
    module_name = cls.__module__

  type_spec_class_name = f'{cls.__name__}_ACTTypeSpec'
  type_spec_name = f'{module_name}.{type_spec_class_name}'

  # If a spec is already registered under this name, return the class it was
  # registered for instead of registering the name again.
  try:
    ts = type_spec.lookup(type_spec_name)
  except ValueError:
    pass  # Not registered yet; fall through and register it below.
  else:
    return ts.value_type.fget(None)

  # If the declared class is already a CompositeTensor subclass, we can avoid
  # affecting the actual type of the returned class. Otherwise, we need to
  # explicitly mix in the CT type, and hence create and return a newly
  # synthesized type.
  if issubclass(cls, composite_tensor.CompositeTensor):

    @type_spec.register(type_spec_name)
    class _AlreadyCTTypeSpec(_AutoCompositeTensorTypeSpec):

      @property
      def value_type(self):
        return cls

    _AlreadyCTTypeSpec.__name__ = type_spec_class_name

    cls._type_spec = property(  # pylint: disable=protected-access
        lambda self: _AlreadyCTTypeSpec.from_instance(self, omit_kwargs))
    return cls

  clsid = (cls.__module__, cls.__name__, omit_kwargs)

  # Check for subclass if retrieving from the _registry, in case the user
  # has redefined the class (e.g. in a REPL/notebook).
  if clsid in _registry and issubclass(_registry[clsid], cls):
    return _registry[clsid]

  @type_spec.register(type_spec_name)
  class _GeneratedCTTypeSpec(_AutoCompositeTensorTypeSpec):

    @property
    def value_type(self):
      # Resolved at call time: the generated class is stored in `_registry`
      # only after this spec class has been defined.
      return _registry[clsid]

  _GeneratedCTTypeSpec.__name__ = type_spec_class_name

  class _AutoCompositeTensor(cls, composite_tensor.CompositeTensor):
    """A per-`cls` subclass of `CompositeTensor`."""

    @property
    def _type_spec(self):
      return _GeneratedCTTypeSpec.from_instance(self, omit_kwargs)

  _AutoCompositeTensor.__name__ = cls.__name__
  _registry[clsid] = _AutoCompositeTensor
  return _AutoCompositeTensor