예제 #1
0
 def testConvertFieldsMismatch(self, field_values, error):
     """Asserts that `convert_fields` raises a `ValueError` matching `error`.

     Args:
       field_values: The field values passed to `convert_fields` together
         with an `x: int` / `y: float` field list.
       error: Regular expression the `ValueError` message must match.
     """
     x_field = struct_field.StructField('x', int)
     y_field = struct_field.StructField('y', float)
     declared_fields = [x_field, y_field]
     with self.assertRaisesRegex(ValueError, error):
         struct_field.convert_fields(declared_fields, field_values)
예제 #2
0
    def testForwardReferences(self):
        """Checks the fields and `__init__` signature of `ForwardRefA`/`ForwardRefB`.

        The field lists are expected to reference the actual classes, while
        the expected constructor signature is built with the annotations
        spelled as strings ('ForwardRefA' / 'ForwardRefB').
        """
        cls_a = ForwardRefA
        cls_b = ForwardRefB

        # Fields of ForwardRefA.
        actual_a_fields = cls_a._tf_struct_fields()
        x_value_type = typing.Tuple[typing.Union[cls_a, cls_b], ...]
        expected_a_fields = (
            struct_field.StructField('x', x_value_type),
            struct_field.StructField('y', cls_b),
        )
        self.assertEqual(actual_a_fields, expected_a_fields)

        # Fields of ForwardRefB.
        actual_b_fields = cls_b._tf_struct_fields()
        expected_b_fields = (
            struct_field.StructField('z', cls_b),
            struct_field.StructField('n', ops.Tensor),
        )
        self.assertEqual(actual_b_fields, expected_b_fields)

        # Check the signature.
        param_kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        self_param = tf_inspect.Parameter('self', param_kind)
        x_annotation = typing.Tuple[typing.Union['ForwardRefA',
                                                 'ForwardRefB'], ...]
        x_param = tf_inspect.Parameter(
            'x', param_kind, annotation=x_annotation)
        y_param = tf_inspect.Parameter(
            'y', param_kind, annotation='ForwardRefB')
        expected_sig = tf_inspect.Signature(
            [self_param, x_param, y_param], return_annotation=cls_a)
        self.assertEqual(tf_inspect.signature(cls_a.__init__), expected_sig)
예제 #3
0
 def testStructFieldRepr(
         self,
         expected,
         name,
         value_type,
         default=struct_field.StructField.NO_DEFAULT):
     """Asserts that `repr()` of a newly built `StructField` equals `expected`.

     Args:
       expected: The expected `repr` string.
       name: Field name passed to `StructField`.
       value_type: Field value type passed to `StructField`.
       default: Field default passed to `StructField`.
     """
     actual_repr = repr(struct_field.StructField(name, value_type, default))
     self.assertEqual(actual_repr, expected)
예제 #4
0
 def testConvertFieldsForSpec(self):
     """Checks `convert_fields_for_spec` on int, tuple-typed and Tensor fields.

     The converted values are read back from `field_values`; the return
     value of `convert_fields_for_spec` is not used.
     """
     y_value_type = typing.Tuple[typing.Union[int, bool], ...]
     fields = [
         struct_field.StructField('x', int),
         struct_field.StructField('y', y_value_type),
         struct_field.StructField('z', ops.Tensor),
     ]
     field_values = {
         'x': 1,
         'y': [1, True, 3],
         'z': tensor_spec.TensorSpec([5, 3]),
     }
     struct_field.convert_fields_for_spec(fields, field_values)

     # Expected contents after conversion: the list given for 'y' is expected
     # to come back as a tuple, 'x' and 'z' are expected to compare equal to
     # their inputs.
     expected_values = {
         'x': 1,
         'y': (1, True, 3),
         'z': tensor_spec.TensorSpec([5, 3]),
     }
     self.assertEqual(set(field_values), set(expected_values))
     for key, expected in expected_values.items():
         self.assertEqual(field_values[key], expected)
예제 #5
0
 def testConvertFields(self):
     """Checks `convert_fields` on int, tuple-typed and Tensor fields.

     The converted values are read back from `field_values`; the return
     value of `convert_fields` is not used.
     """
     y_value_type = typing.Tuple[typing.Union[int, bool], ...]
     fields = [
         struct_field.StructField('x', int),
         struct_field.StructField('y', y_value_type),
         struct_field.StructField('z', ops.Tensor),
     ]
     field_values = {
         'x': 1,
         'y': [1, True, 3],
         'z': [[1, 2], [3, 4], [5, 6]],
     }
     struct_field.convert_fields(fields, field_values)

     # No keys may be added or dropped by the conversion.
     self.assertEqual(set(field_values), {'x', 'y', 'z'})
     self.assertEqual(field_values['x'], 1)
     # The list given for 'y' is expected to come back as a tuple.
     self.assertEqual(field_values['y'], (1, True, 3))
     # The nested list given for 'z' is expected to come back as a Tensor
     # holding the same values.
     converted_z = field_values['z']
     self.assertIsInstance(converted_z, ops.Tensor)
     self.assertAllEqual(converted_z, [[1, 2], [3, 4], [5, 6]])
예제 #6
0
 def testStructFieldConstruction(
         self,
         name,
         value_type,
         default=struct_field.StructField.NO_DEFAULT,
         converted_default=None):
     """Builds a `StructField` and checks its name, value type and default.

     Args:
       name: Field name passed to `StructField`.
       value_type: Field value type passed to `StructField`.
       default: Field default, or a callable that returns it.
       converted_default: If not None, the value `field.default` is compared
         against instead of `default`.
     """
     # A callable default is a deferred constructor (the default contains a
     # tensor); call it to obtain the actual default.
     actual_default = default() if callable(default) else default
     field = struct_field.StructField(name, value_type, actual_default)

     expected_default = (
         actual_default if converted_default is None else converted_default)

     self.assertEqual(field.name, name)
     self.assertEqual(field.value_type, value_type)

     # Tensor-valued defaults are compared with assertAllEqual, everything
     # else with assertEqual.
     tensor_types = (ops.Tensor, ragged_tensor.RaggedTensor)
     if not isinstance(field.default, tensor_types):
         self.assertEqual(field.default, expected_default)
     else:
         self.assertAllEqual(field.default, expected_default)
예제 #7
0
    def _tf_struct_fields(cls):  # pylint: disable=no-self-argument
        """Returns the `StructField`s describing the fields of this struct.

        One field is built per type annotation, in the order the annotations
        are reported.  A class attribute with the same name as an annotation
        supplies that field's default; otherwise the default is
        `StructField.NO_DEFAULT`.

        The result is cached on `cls` once all forward references have been
        resolved.  The cache is read from `cls.__dict__` rather than through
        normal attribute lookup, so a subclass never picks up the tuple cached
        by one of its base classes (which would omit the subclass's own
        fields).

        Returns:
          A tuple of `StructField` objects.  Forward references are resolved
          if possible, or left unresolved otherwise.
        """
        # Do not inherit the cache: only a value stored directly on this class
        # describes this class's fields.
        cached_fields = cls.__dict__.get('_tf_struct_cached_fields')
        if cached_fields is not None:
            return cached_fields

        try:
            type_hints = typing.get_type_hints(cls)
        except (NameError, AttributeError):
            # Unresolved forward reference -- gather type hints manually.
            # * NameError comes from an annotation like `Foo` where class
            #   `Foo` hasn't been defined yet.
            # * AttributeError comes from an annotation like `foo.Bar`, where
            #   the module `foo` exists but `Bar` hasn't been defined yet.
            # Note: If a user attempts to instantiate a `Struct` type that still
            # has unresolved forward references (e.g., because of a typo or a
            # missing import), then the constructor will raise an exception.
            type_hints = {}
            # Walk the MRO from the most basic class to `cls` so annotations
            # on more derived classes override those of their bases.
            for base in reversed(cls.__mro__):
                type_hints.update(base.__dict__.get('__annotations__', {}))
            ok_to_cache = False
        else:
            ok_to_cache = True  # all forward references have been resolved.

        fields = tuple(
            struct_field.StructField(
                name, value_type,
                getattr(cls, name, struct_field.StructField.NO_DEFAULT))
            for (name, value_type) in type_hints.items())

        if ok_to_cache:
            # Stored on `cls` itself, so it lands in `cls.__dict__`.
            cls._tf_struct_cached_fields = fields

        return fields
예제 #8
0
 def _tf_struct_fields(cls):
     """Builds a `StructField` for each non-reserved name in `cls.__dict__`.

     Returns:
       A list of `StructField` objects, one per name in `cls.__dict__` that
       `StructField.is_reserved_name` does not flag, each created with a
       value type of `None`.
     """
     untyped_fields = []
     for attr_name in cls.__dict__:
         if struct_field.StructField.is_reserved_name(attr_name):
             continue
         untyped_fields.append(struct_field.StructField(attr_name, None))
     return untyped_fields
예제 #9
0
 def testStructFieldConstructionError(
         self, name, value_type, default, error):
     """Asserts that building a `StructField` raises a matching `TypeError`.

     Args:
       name: Field name passed to `StructField`.
       value_type: Field value type passed to `StructField`.
       default: Field default, or a callable that returns it.
       error: Regular expression the `TypeError` message must match.
     """
     # A callable default is a deferred constructor (the default contains a
     # tensor); it is called before entering the assertRaises context.
     actual_default = default() if callable(default) else default
     with self.assertRaisesRegex(TypeError, error):
         struct_field.StructField(name, value_type, actual_default)