def test_dict_to_tf_example_error_reraise(self):
    # Test error reraise in _dict_to_tf_example: the value below has three
    # elements while its spec declares shape (2,), which is expected to make
    # serialization fail with a ValueError that names the `input` feature.
    mismatched_spec = feature_lib.TensorInfo(shape=(2,), dtype=tf.int64)
    expected_message = 'Error while serializing feature `input`:'
    with self.assertRaisesRegex(ValueError, expected_message):
        example_serializer._dict_to_tf_example(
            {'input': [1, 2, 3]},
            {'input': mismatched_spec},
        )

 def test_ragged_dict_to_tf_example_empty(self):
     # An empty value for a feature declared with shape (None, None) and
     # sequence_rank=2 must serialize to empty flat-values and empty
     # row-lengths int64 lists.
     ragged_spec = feature_lib.TensorInfo(
         shape=(None, None),
         dtype=tf.int64,
         sequence_rank=2,
     )
     proto = example_serializer._dict_to_tf_example(
         {'input': []},
         {'input': ragged_spec},
     )
     serialized = proto.features.feature
     # Checked in the same order as before; the first mismatch fails the test.
     for key in ('input/ragged_flat_values', 'input/ragged_row_lengths_0'):
         self.assertEqual([], list(serialized[key].int64_list.value))

    def test_dict_to_example(self):
        # Serialize one value of each supported kind without passing a spec,
        # then check which proto list (int64 / bytes / float) received it.
        zero_bytes = np.zeros(2, dtype=np.uint8).tobytes()
        raw_values = {
            "a": 1,
            "a2": np.array(1),
            "b": ["foo", "bar"],
            "b2": np.array(["foo", "bar"]),
            "c": [2.0],
            "c2": np.array([2.0]),
            # Empty values supported when type is defined
            "d": np.array([], dtype=np.int32),
            # Support for byte strings, both scalar and in a list
            "e": zero_bytes,
            "e2": [zero_bytes] * 2,
        }
        serialized = example_serializer._dict_to_tf_example(raw_values)

        # (feature key, proto list field, expected values), checked in order.
        expectations = [
            ("a", "int64_list", [1]),
            ("a2", "int64_list", [1]),
            ("b", "bytes_list", [b"foo", b"bar"]),
            ("b2", "bytes_list", [b"foo", b"bar"]),
            ("c", "float_list", [2.0]),
            ("c2", "float_list", [2.0]),
            ("d", "int64_list", []),
            ("e", "bytes_list", [b"\x00\x00"]),
            ("e2", "bytes_list", [b"\x00\x00", b"\x00\x00"]),
        ]
        features = serialized.features.feature
        for key, list_field, expected in expectations:
            actual = list(getattr(features[key], list_field).value)
            self.assertEqual(expected, actual)

        # An empty value with no dtype information must be rejected.
        with self.assertRaisesWithPredicateMatch(
                ValueError, "Received an empty"):
            example_serializer._dict_to_tf_example({"empty": []})

        # An unsupported dtype (complex64) must be rejected.
        complex_values = np.zeros(shape=(5,), dtype=np.complex64)
        with self.assertRaisesWithPredicateMatch(
                ValueError, "not support type"):
            example_serializer._dict_to_tf_example(
                {"wrong_type": complex_values})