def testConvertFieldsMismatch(self, field_values, error):
  """Checks that convert_fields raises ValueError for mismatched values.

  Args:
    field_values: Candidate values for the fields `x` (int) and `y` (float).
      These are presumably supplied by a parameterization decorator that is
      not visible here.
    error: Regex that the raised ValueError's message must match.
  """
  expected_fields = [
      struct_field.StructField('x', int),
      struct_field.StructField('y', float),
  ]
  with self.assertRaisesRegex(ValueError, error):
    struct_field.convert_fields(expected_fields, field_values)
def testConvertFields(self):
  """Checks that convert_fields rewrites each value in place per its field type."""
  int_or_bool_tuple = typing.Tuple[typing.Union[int, bool], ...]
  fields = [
      struct_field.StructField('x', int),
      struct_field.StructField('y', int_or_bool_tuple),
      struct_field.StructField('z', ops.Tensor),
  ]
  field_values = {
      'x': 1,
      'y': [1, True, 3],
      'z': [[1, 2], [3, 4], [5, 6]],
  }

  # convert_fields mutates `field_values`; it has no useful return value here.
  struct_field.convert_fields(fields, field_values)

  # No keys are added or removed by the conversion.
  self.assertEqual(set(field_values), {'x', 'y', 'z'})
  self.assertEqual(field_values['x'], 1)
  # The list given for a variadic Tuple field comes back as a tuple.
  self.assertEqual(field_values['y'], (1, True, 3))
  # The nested list given for a Tensor field comes back as a Tensor.
  self.assertIsInstance(field_values['z'], ops.Tensor)
  self.assertAllEqual(field_values['z'], [[1, 2], [3, 4], [5, 6]])
def _tf_struct_convert_fields(self):
  """Converts this object's attribute values in place to their field types.

  Delegates to `struct_field.convert_fields`, passing the instance's
  `__dict__` directly, so converted values replace the stored attributes.
  """
  declared_fields = self._tf_struct_fields()
  struct_field.convert_fields(declared_fields, self.__dict__)