Пример #1
0
 def testConvertFieldsMismatch(self, field_values, error):
     """convert_fields raises ValueError when values don't fit the fields.

     The declared fields are `x: int` and `y: float`. `field_values` and
     `error` are presumably supplied by a parameterizing decorator outside
     this view (TODO confirm). `error` is the regex that the raised
     ValueError's message must match.
     """
     declared = [
         struct_field.StructField('x', int),
         struct_field.StructField('y', float),
     ]
     # Callable form of assertRaisesRegex: it invokes
     # convert_fields(declared, field_values) and checks the exception.
     self.assertRaisesRegex(
         ValueError, error, struct_field.convert_fields, declared,
         field_values)
Пример #2
0
 def testConvertFields(self):
     """convert_fields rewrites each value in the dict to its declared type.

     Covers three field kinds:
     - a plain `int` field;
     - a variadic tuple of `int`/`bool`, given as a list;
     - an `ops.Tensor` field, given as a nested Python list.
     """
     fields = [
         struct_field.StructField('x', int),
         struct_field.StructField(
             'y', typing.Tuple[typing.Union[int, bool], ...]),
         struct_field.StructField('z', ops.Tensor),
     ]
     field_values = {'x': 1, 'y': [1, True, 3], 'z': [[1, 2], [3, 4], [5, 6]]}

     struct_field.convert_fields(fields, field_values)

     # The same dict is inspected afterwards: conversion happens in place
     # and must not add or drop keys.
     self.assertEqual(set(field_values), {'x', 'y', 'z'})
     self.assertEqual(field_values['x'], 1)
     # The list given for the Tuple-typed field is expected back as a tuple.
     self.assertEqual(field_values['y'], (1, True, 3))
     converted_z = field_values['z']
     self.assertIsInstance(converted_z, ops.Tensor)
     self.assertAllEqual(converted_z, [[1, 2], [3, 4], [5, 6]])
Пример #3
0
 def _tf_struct_convert_fields(self):
     """Convert this instance's attributes to their declared field types.

     Hands the field list from `self._tf_struct_fields()` and the instance's
     own `__dict__` to `struct_field.convert_fields`. That helper presumably
     rewrites the dict's values in place, which would update the attributes
     directly on the instance (TODO confirm). Returns None.
     """
     declared_fields = self._tf_struct_fields()
     struct_field.convert_fields(declared_fields, self.__dict__)