def test_dict_to_tf_example_error_reraise(self):
    """Serialization failures are re-raised naming the offending feature."""
    # Three values do not match the declared shape (2,), so serialization
    # of `input` is expected to fail with a wrapped ValueError.
    mismatched_values = {'input': [1, 2, 3]}
    declared_specs = dict(
        input=feature_lib.TensorInfo(shape=(2,), dtype=tf.int64),
    )
    with self.assertRaisesRegex(
        ValueError, 'Error while serializing feature `input`:'
    ):
        example_serializer._dict_to_tf_example(mismatched_values, declared_specs)
def test_ragged_dict_to_tf_example_empty(self):
    """An empty ragged feature serializes to empty flat values / row lengths."""
    ragged_info = feature_lib.TensorInfo(
        shape=(None, None),
        dtype=tf.int64,
        sequence_rank=2,
    )
    proto = example_serializer._dict_to_tf_example(
        {'input': []}, {'input': ragged_info}
    )
    serialized = proto.features.feature
    # Both the flattened values and the first-level row lengths must be empty.
    for key in ('input/ragged_flat_values', 'input/ragged_row_lengths_0'):
        self.assertEqual([], list(serialized[key].int64_list.value))
def test_dict_to_example(self):
    """Checks dtype inference and error cases of `_dict_to_tf_example`."""
    zero_bytes = np.zeros(2, dtype=np.uint8).tobytes()
    proto = example_serializer._dict_to_tf_example({
        "a": 1,
        "a2": np.array(1),
        "b": ["foo", "bar"],
        "b2": np.array(["foo", "bar"]),
        "c": [2.0],
        "c2": np.array([2.0]),
        # Empty values are supported when the dtype is defined.
        "d": np.array([], dtype=np.int32),
        # Raw byte strings are supported.
        "e": zero_bytes,
        "e2": [zero_bytes] * 2,
    })
    features = proto.features.feature

    # (feature key, proto list field, expected values), checked in order.
    expectations = [
        ("a", "int64_list", [1]),
        ("a2", "int64_list", [1]),
        ("b", "bytes_list", [b"foo", b"bar"]),
        ("b2", "bytes_list", [b"foo", b"bar"]),
        ("c", "float_list", [2.0]),
        ("c2", "float_list", [2.0]),
        ("d", "int64_list", []),
        ("e", "bytes_list", [b"\x00\x00"]),
        ("e2", "bytes_list", [b"\x00\x00", b"\x00\x00"]),
    ]
    for key, list_field, expected_values in expectations:
        actual = getattr(features[key], list_field).value
        self.assertEqual(expected_values, list(actual))

    # An empty value with no defined type cannot be serialized.
    with self.assertRaisesWithPredicateMatch(ValueError, "Received an empty"):
        example_serializer._dict_to_tf_example({"empty": []})

    # An unsupported dtype must be rejected.
    with self.assertRaisesWithPredicateMatch(ValueError, "not support type"):
        example_serializer._dict_to_tf_example(
            {"wrong_type": np.zeros(shape=(5,), dtype=np.complex64)}
        )