def _add_type_spec(cls):
    """Creates the nested `Spec` TypeSpec class for tf.ExtensionType `cls`.

    The generated class is named `<cls.__name__>.Spec`, gets `value_type = cls`,
    and is stored on `cls` as `cls.Spec`. If `cls` itself defines a nested
    `Spec` class, that class's attributes (other than `__module__`, `__dict__`
    and `__weakref__`) are copied onto the generated class.

    Subclasses of `BatchableExtensionType` get `BatchableExtensionTypeSpec` as
    the base class, and `cls.__batch_encoder__` (when present) is copied onto
    the spec unless the user `Spec` already defines one. All other classes get
    `ExtensionTypeSpec` as the base class.

    If the generated class has its own `__init__` (i.e. one copied from the user
    `Spec`), it is passed to `_wrap_user_constructor`; otherwise
    `_build_spec_constructor` builds one. `_type_spec` is then removed from
    `cls.__abstractmethods__`, and if `cls` defines `__name__` directly, the
    spec is registered under `<__name__>.Spec`.

    Args:
      cls: The tf.ExtensionType subclass that receives the nested TypeSpec.

    Raises:
      ValueError: If the user `Spec` uses a reserved name or defines an
        attribute that shadows a field of `cls`, or if `__batch_encoder__` is
        defined for a class that is not a `BatchableExtensionType`.
    """
    spec_name = cls.__name__ + '.Spec'
    spec_qualname = cls.__qualname__ + '.Spec'

    # A class built dynamically via `type()` reports module 'abc' unless
    # `__module__` is supplied, so pass the owner's module through.
    spec_dict = {'value_type': cls, '__module__': cls.__module__}

    # Merge in customizations from a user-written nested `Spec`, validating
    # every attribute name before deciding whether to copy it.
    user_spec = cls.__dict__.get('Spec', None)
    if user_spec is not None:
        not_copied = ('__module__', '__dict__', '__weakref__')
        for attr_name, attr_value in user_spec.__dict__.items():
            if extension_type_field.ExtensionTypeField.is_reserved_name(
                    attr_name):
                raise ValueError(f'TypeSpec {spec_qualname} uses reserved '
                                 f"name '{attr_name}'.")
            if cls._tf_extension_type_has_field(attr_name):  # pylint: disable=protected-access
                raise ValueError(
                    f"TypeSpec {spec_qualname} defines a variable '{attr_name}'"
                    f' which shadows a field in {cls.__qualname__}')
            if attr_name not in not_copied:
                spec_dict[attr_name] = attr_value

    # Choose the base class; `__batch_encoder__` only makes sense for
    # batchable extension types.
    is_batchable = issubclass(cls, BatchableExtensionType)
    cls_has_encoder = hasattr(cls, '__batch_encoder__')
    if not is_batchable:
        base_spec_cls = ExtensionTypeSpec
        if cls_has_encoder or '__batch_encoder__' in spec_dict:
            raise ValueError('__batch_encoder__ should only be defined for '
                             'BatchableExtensionType classes.')
    else:
        base_spec_cls = BatchableExtensionTypeSpec
        if cls_has_encoder and '__batch_encoder__' not in spec_dict:
            spec_dict['__batch_encoder__'] = cls.__batch_encoder__

    # Create the TypeSpec and nest it inside `cls`.
    spec = type(spec_name, (base_spec_cls, ), spec_dict)
    spec.__qualname__ = spec_qualname
    cls.Spec = spec

    # Keep a user-provided constructor (wrapped), or generate one.
    install_constructor = (_wrap_user_constructor
                           if '__init__' in spec.__dict__ else
                           _build_spec_constructor)
    install_constructor(spec)

    cls.__abstractmethods__ -= {'_type_spec'}

    # An explicit `__name__` on `cls` opts the spec into the registry (so it
    # can be used in SavedModel signatures).
    if '__name__' in cls.__dict__:
        registered_name = cls.__dict__['__name__'] + '.Spec'
        type_spec.register(registered_name)(spec)
Esempio n. 2
0
  def testRegistryDuplicateErrors(self):
    """Re-using a registered name, or re-registering a class, raises."""
    import re  # pylint: disable=g-import-not-at-top
    # Derive the qualified class name instead of hard-coding "__main__", so the
    # test also passes when this module is imported (e.g. by a test runner)
    # rather than run as a script. re.escape keeps the dots literal.
    qualified_cls = (
        f"{TwoCompositesSpec.__module__}.{TwoCompositesSpec.__name__}")

    with self.assertRaisesRegex(
        ValueError,
        re.escape(f"Name tf.TwoCompositesSpec has already been registered "
                  f"for class {qualified_cls}.")):

      @type_spec.register("tf.TwoCompositesSpec")  # pylint: disable=unused-variable
      class NewTypeSpec(TwoCompositesSpec):
        pass

    with self.assertRaisesRegex(
        ValueError,
        re.escape(f"Class {qualified_cls} has already been "
                  "registered with name tf.TwoCompositesSpec")):
      type_spec.register("tf.NewName")(TwoCompositesSpec)
Esempio n. 3
0
def _add_type_spec(cls):
    """Creates a nested TypeSpec class for tf.Struct subclass `cls`.

    The generated class subclasses `StructSpec`, is named `<cls.__name__>.Spec`,
    has `value_type = cls`, and is stored on `cls` as `cls.Spec`. A constructor
    is generated for it, `_type_spec` is removed from
    `cls.__abstractmethods__`, and if `cls` defines `__name__` directly the spec
    is registered under `<__name__>.Spec`.

    Args:
      cls: The tf.Struct subclass that receives the nested TypeSpec.
    """
    # Build the TypeSpec class for this struct type, and add it as a
    # nested class.
    spec_name = cls.__name__ + '.Spec'
    # Pass `__module__` explicitly: without it a class created via `type()`
    # here can report the wrong module (e.g. 'abc' when the base class uses
    # ABCMeta) instead of the module that defines `cls`.
    spec_dict = {'value_type': cls, '__module__': cls.__module__}
    spec = type(spec_name, (StructSpec, ), spec_dict)
    # Make the qualname reflect the nesting inside `cls`.
    spec.__qualname__ = cls.__qualname__ + '.Spec'
    setattr(cls, 'Spec', spec)

    # Build a constructor for the TypeSpec class.
    _build_spec_constructor(spec)

    cls.__abstractmethods__ -= {'_type_spec'}

    # If the user included an explicit `__name__` attribute, then use that to
    # register the TypeSpec (so it can be used in SavedModel signatures).
    if '__name__' in cls.__dict__:
        type_spec.register(cls.__dict__['__name__'] + '.Spec')(spec)
Esempio n. 4
0
  def testRegistryTypeErrors(self):
    """`type_spec.register` rejects a non-string name and a non-TypeSpec cls."""
    # Anything that is not a string is refused as the registry name.
    non_string_names = (None, TwoTensorsSpec)
    for bad_name in non_string_names:
      with self.assertRaisesRegex(TypeError, "Expected `name` to be a string"):
        type_spec.register(bad_name)

    # The decorated object must be a TypeSpec subclass.
    non_spec_classes = (None, ragged_tensor.RaggedTensor)
    for bad_cls in non_spec_classes:
      with self.assertRaisesRegex(TypeError, "Expected `cls` to be a TypeSpec"):
        type_spec.register("tf.foo")(bad_cls)
Esempio n. 5
0
def _add_type_spec(cls):
    """Attaches a generated `Spec` TypeSpec class to tf.ExtensionType `cls`.

    The generated class subclasses `ExtensionTypeSpec`, is named
    `<cls.__name__>.Spec` with qualname `<cls.__qualname__>.Spec`, has
    `value_type = cls`, and is stored as `cls.Spec`. A constructor is generated
    for it, `_type_spec` is removed from `cls.__abstractmethods__`, and if
    `cls` defines `__name__` directly the spec is registered under
    `<__name__>.Spec`.

    Args:
      cls: The tf.ExtensionType subclass that receives the nested TypeSpec.
    """
    # A dynamically created class gets module 'abc' by default, so hand the
    # owner's module to `type()` explicitly.
    namespace = {'value_type': cls, '__module__': cls.__module__}
    spec = type(f'{cls.__name__}.Spec', (ExtensionTypeSpec, ), namespace)
    spec.__qualname__ = f'{cls.__qualname__}.Spec'
    cls.Spec = spec

    _build_spec_constructor(spec)

    cls.__abstractmethods__ -= {'_type_spec'}

    # Only classes that spell out `__name__` themselves are registered (which
    # lets the spec appear in SavedModel signatures).
    if '__name__' in cls.__dict__:
        registry_key = cls.__dict__['__name__'] + '.Spec'
        type_spec.register(registry_key)(spec)
def type_spec_register(name, allow_overwrite=True):
    """Returns a class decorator that registers a `TypeSpec` subclass as `name`.

    Works like TensorFlow's `type_spec.register`, except that by default any
    `TypeSpec` already registered under `name` is evicted first, so the newly
    decorated class takes over that name. This allows `AutoCompositeTensor`
    subclasses to be re-defined in test environments and iPython.

    The eviction happens when this function is called, before the returned
    decorator is applied to anything.

    Args:
      name: The name of the type spec, of the form
        `"{project_name}.{type_name}"`, e.g. `"my_project.MyTypeSpec"`.
      allow_overwrite: `bool`. If `True`, an existing registry entry keyed by
        `name` is removed before registering. If `False`, this is equivalent
        to `type_spec.register(name)`.

    Returns:
      A class decorator that registers the decorated class with the given name.
    """
    # pylint: disable=protected-access
    if allow_overwrite:
        name_to_cls = type_spec._NAME_TO_TYPE_SPEC
        if name in name_to_cls:
            # Drop both directions of the old mapping so the registry stays
            # consistent for the upcoming registration.
            cls_to_name = type_spec._TYPE_SPEC_TO_NAME
            evicted_cls = name_to_cls.pop(name)
            cls_to_name.pop(evicted_cls)
    return type_spec.register(name)
Esempio n. 7
0
import contextlib  # pylint: disable=g-import-not-at-top,g-bad-import-order


@contextlib.contextmanager
def temporarily_register_type_spec(name, cls):
    """Context manager for making temporary changes to the TypeSpec registry.

    On entry, `cls` is registered under `name` via `type_spec.register`. On
    exit, including exit caused by an exception in the `with` body, both
    registry entries (`cls -> name` and `name -> cls`) are removed again.

    Args:
      name: The registry name to associate with `cls` while the block runs.
      cls: The TypeSpec subclass to register temporarily.

    Yields:
      None.

    Raises:
      KeyError: If the registry entries were removed inside the block.
      AssertionError: If the popped entries do not match `name` / `cls`
        (only checked when assertions are enabled).
    """
    # pylint: disable=protected-access
    type_spec.register(name)(cls)
    try:
        yield
    finally:
        # Do the cleanup outside of `assert`: `python -O` strips assert
        # statements, which previously skipped the pops entirely and left
        # the registration in place.
        removed_name = type_spec._TYPE_SPEC_TO_NAME.pop(cls)
        removed_cls = type_spec._NAME_TO_TYPE_SPEC.pop(name)
        # Sanity-check that nobody re-pointed the registry during the block.
        assert removed_name == name
        assert removed_cls is cls
Esempio n. 8
0
 def testRegistryNameErrors(self):
     """Malformed registry names make `type_spec.register` raise ValueError."""
     # None of these has the dotted "{project}.{type}" form.
     malformed_names = ("foo", "", "hello world")
     for candidate in malformed_names:
         with self.assertRaises(ValueError):
             type_spec.register(candidate)