def _add_type_spec(cls):
  """Builds the nested `Spec` TypeSpec class for ExtensionType subclass `cls`.

  If `cls` itself (not a base class) defines a nested `Spec` class, that
  class's attributes are copied into the generated TypeSpec. `__module__`,
  `__dict__` and `__weakref__` are not copied. Each copied attribute must not
  use a reserved name and must not shadow a field of `cls`.

  The generated class derives from `BatchableExtensionTypeSpec` when `cls` is a
  `BatchableExtensionType`, and from `ExtensionTypeSpec` otherwise. It is
  attached as `cls.Spec` and given a constructor. `_type_spec` is removed from
  `cls.__abstractmethods__`. If `cls` defines `__name__` in its own class body,
  the Spec is registered under `<__name__>.Spec` so it can be used in
  SavedModel signatures.

  Args:
    cls: The ExtensionType subclass being set up.

  Raises:
    ValueError: If the user-supplied `Spec` uses a reserved name or shadows a
      field of `cls`, or if `__batch_encoder__` is defined (on `cls` or its
      `Spec`) for a class that is not a `BatchableExtensionType`.
  """
  qualified_spec_name = cls.__qualname__ + '.Spec'
  # Set __module__ explicitly: a dynamically created class would otherwise
  # report module='abc'.
  namespace = {'value_type': cls, '__module__': cls.__module__}

  # Copy user-supplied customizations from a `Spec` nested directly in `cls`.
  custom_spec = cls.__dict__.get('Spec')
  custom_members = () if custom_spec is None else custom_spec.__dict__.items()
  for attr, val in custom_members:
    if extension_type_field.ExtensionTypeField.is_reserved_name(attr):
      raise ValueError(f'TypeSpec {qualified_spec_name} uses reserved '
                       f"name '{attr}'.")
    if cls._tf_extension_type_has_field(attr):  # pylint: disable=protected-access
      raise ValueError(
          f"TypeSpec {qualified_spec_name} defines a variable '{attr}'"
          f' which shadows a field in {cls.__qualname__}')
    # Class bookkeeping entries are not carried over to the generated class.
    if attr not in ('__module__', '__dict__', '__weakref__'):
      namespace[attr] = val

  # Pick the base class; `__batch_encoder__` is only legal for batchables.
  if not issubclass(cls, BatchableExtensionType):
    if hasattr(cls, '__batch_encoder__') or '__batch_encoder__' in namespace:
      raise ValueError('__batch_encoder__ should only be defined for '
                       'BatchableExtensionType classes.')
    spec_base = ExtensionTypeSpec
  else:
    spec_base = BatchableExtensionTypeSpec
    # A `__batch_encoder__` on the user's Spec takes precedence over the one
    # defined on (or inherited by) `cls`.
    if (hasattr(cls, '__batch_encoder__') and
        '__batch_encoder__' not in namespace):
      namespace['__batch_encoder__'] = cls.__batch_encoder__

  # Create the TypeSpec and expose it as the nested class `cls.Spec`.
  generated_spec = type(cls.__name__ + '.Spec', (spec_base,), namespace)
  generated_spec.__qualname__ = qualified_spec_name
  cls.Spec = generated_spec

  # A user-written `__init__` gets wrapped; otherwise one is built for it.
  add_constructor = (
      _wrap_user_constructor if '__init__' in generated_spec.__dict__
      else _build_spec_constructor)
  add_constructor(generated_spec)

  cls.__abstractmethods__ -= {'_type_spec'}

  # An explicit `__name__` in the class body opts into registry registration.
  class_namespace = cls.__dict__
  if '__name__' in class_namespace:
    registered_name = class_namespace['__name__'] + '.Spec'
    type_spec.register(registered_name)(generated_spec)
def testRegistryDuplicateErrors(self):
  """Re-using a registered name, or re-registering a class, raises ValueError."""
  import re  # pylint: disable=g-import-not-at-top
  # Derive the qualified class name instead of hard-coding "__main__", so the
  # test passes both when this file runs as a script and when it is imported
  # (e.g. by pytest). re.escape makes the '.' separators match literally.
  qualified_cls = f"{TwoCompositesSpec.__module__}.{TwoCompositesSpec.__name__}"

  # A new class may not take a name that is already registered.
  with self.assertRaisesRegex(
      ValueError,
      re.escape(f"Name tf.TwoCompositesSpec has already been registered "
                f"for class {qualified_cls}.")):

    @type_spec.register("tf.TwoCompositesSpec")  # pylint: disable=unused-variable
    class NewTypeSpec(TwoCompositesSpec):
      pass

  # An already-registered class may not be registered under a second name.
  with self.assertRaisesRegex(
      ValueError,
      re.escape(f"Class {qualified_cls} has already been registered with "
                f"name tf.TwoCompositesSpec")):
    type_spec.register("tf.NewName")(TwoCompositesSpec)
def _add_type_spec(cls):
  """Creates a nested TypeSpec class for tf.Struct subclass `cls`.

  The generated class derives from `StructSpec`, has `value_type` set to `cls`,
  and is attached as `cls.Spec`. `_type_spec` is removed from
  `cls.__abstractmethods__`. If `cls` defines `__name__` in its own class body,
  the Spec is registered under `<__name__>.Spec` so it can be used in
  SavedModel signatures.

  Args:
    cls: The Struct subclass being set up.
  """
  # Build the TypeSpec class for this struct type, and add it as a
  # nested class.
  spec_name = cls.__name__ + '.Spec'
  # Set __module__ explicitly, as the ExtensionType version of this helper
  # does. Without it, a dynamically created class can report the module of
  # whichever frame created it (documented as 'abc' for ExtensionType specs)
  # rather than the module that defines `cls`.
  spec_dict = {'value_type': cls, '__module__': cls.__module__}
  spec = type(spec_name, (StructSpec, ), spec_dict)
  # Give the nested class a qualname that matches where it lives (`cls.Spec`).
  spec.__qualname__ = cls.__qualname__ + '.Spec'
  setattr(cls, 'Spec', spec)

  # Build a constructor for the TypeSpec class.
  _build_spec_constructor(spec)

  cls.__abstractmethods__ -= {'_type_spec'}

  # If the user included an explicit `__name__` attribute, then use that to
  # register the TypeSpec (so it can be used in SavedModel signatures).
  if '__name__' in cls.__dict__:
    type_spec.register(cls.__dict__['__name__'] + '.Spec')(spec)
def testRegistryTypeErrors(self):
  """Non-string names and non-TypeSpec classes are rejected with TypeError."""
  bad_name_pattern = "Expected `name` to be a string"
  bad_cls_pattern = "Expected `cls` to be a TypeSpec"

  # `register` validates its `name` argument immediately.
  for invalid_name in (None, TwoTensorsSpec):
    with self.assertRaisesRegex(TypeError, bad_name_pattern):
      type_spec.register(invalid_name)

  # The returned decorator validates the class it is applied to.
  for invalid_cls in (None, ragged_tensor.RaggedTensor):
    with self.assertRaisesRegex(TypeError, bad_cls_pattern):
      type_spec.register("tf.foo")(invalid_cls)
def _add_type_spec(cls):
  """Attaches a generated `Spec` TypeSpec subclass to ExtensionType `cls`.

  The generated class derives from `ExtensionTypeSpec`, records `cls` as its
  `value_type`, and is reachable as `cls.Spec`. `_type_spec` is removed from
  `cls.__abstractmethods__`. If `cls` defines `__name__` in its own class body,
  the Spec is also registered under `<__name__>.Spec` so it can be used in
  SavedModel signatures.

  Args:
    cls: The ExtensionType subclass being set up.
  """
  # A dynamically created class would otherwise report module='abc', so pin
  # `__module__` to the module that defines `cls`.
  namespace = {'value_type': cls, '__module__': cls.__module__}
  generated_spec = type(cls.__name__ + '.Spec', (ExtensionTypeSpec,),
                        namespace)
  generated_spec.__qualname__ = f'{cls.__qualname__}.Spec'
  cls.Spec = generated_spec

  # Give the generated TypeSpec class a constructor.
  _build_spec_constructor(generated_spec)

  # Mark `_type_spec` as no longer abstract on `cls`.
  cls.__abstractmethods__ -= {'_type_spec'}

  # Only an explicit `__name__` in the class body triggers registration.
  class_namespace = cls.__dict__
  if '__name__' in class_namespace:
    registered_name = class_namespace['__name__'] + '.Spec'
    type_spec.register(registered_name)(generated_spec)
def type_spec_register(name, allow_overwrite=True):
  """Decorator used to register a unique name for a TypeSpec subclass.

  Unlike TensorFlow's `type_spec.register`, this function allows a new
  `TypeSpec` to be registered with a `name` that already appears in the
  registry (overwriting the `TypeSpec` already registered with that name). This
  allows for re-definition of `AutoCompositeTensor` subclasses in test
  environments and iPython.

  The existing entry is evicted only when the returned decorator is applied to
  a class. If `type_spec.register` then rejects that class (for example, it is
  not a `TypeSpec`, or it is already registered under another name), the
  evicted entry is restored and the error is re-raised.

  Args:
    name: The name of the type spec. Must have the form
      `"{project_name}.{type_name}"`. E.g. `"my_project.MyTypeSpec"`.
    allow_overwrite: `bool`, if `True` then the entry in the `TypeSpec` registry
      keyed by `name` will be overwritten if it exists. If `False`, then
      behavior is the same as `type_spec.register`.

  Returns:
    A class decorator that registers the decorated class with the given name.
  """
  # Validates `name` up front and returns TensorFlow's own decorator.
  register_fn = type_spec.register(name)
  if not allow_overwrite:
    return register_fn

  def decorator_fn(cls):
    # pylint: disable=protected-access
    name_to_spec = type_spec._NAME_TO_TYPE_SPEC
    spec_to_name = type_spec._TYPE_SPEC_TO_NAME
    # Evict any existing entry now that the class being registered is known,
    # so TensorFlow's duplicate-name check does not fire.
    previous = name_to_spec.pop(name, None)
    if previous is not None:
      spec_to_name.pop(previous)
    try:
      return register_fn(cls)
    except Exception:
      # Registration was rejected; put back the entry we evicted so the
      # registry is left unchanged.
      if previous is not None:
        name_to_spec[name] = previous
        spec_to_name[previous] = name
      raise

  return decorator_fn
def temporarily_register_type_spec(name, cls):
  """Context manager for making temporary changes to the TypeSpec registry.

  Registers `cls` under `name`, yields once, then removes both registry
  entries. Cleanup runs even if the managed block raises.

  NOTE(review): this is a generator. It must be wrapped with
  `contextlib.contextmanager` (no decorator is visible in this snippet) to be
  usable in a `with` statement. Confirm the decorator is applied where this is
  defined.

  Args:
    name: Registry name to register `cls` under.
    cls: The TypeSpec subclass to register temporarily.

  Raises:
    AssertionError: On exit (when assertions are enabled), if the registry
      entries removed do not match `name` / `cls`.
  """
  type_spec.register(name)(cls)
  try:
    yield
  finally:
    # pylint: disable=protected-access
    # Pop unconditionally. Putting the pops inside `assert` would skip them
    # entirely under `python -O` and leak the registration. It would also skip
    # the second pop whenever the first check failed.
    removed_name = type_spec._TYPE_SPEC_TO_NAME.pop(cls)
    removed_cls = type_spec._NAME_TO_TYPE_SPEC.pop(name)
    assert removed_name == name
    assert removed_cls is cls
def testRegistryNameErrors(self):
  """Malformed registry names are rejected with ValueError."""
  malformed_names = ("foo", "", "hello world")
  for candidate in malformed_names:
    # Callable form: assertRaises invokes register(candidate) itself.
    self.assertRaises(ValueError, type_spec.register, candidate)