def test_dict_to_example(self):
    """Check `_dict_to_tf_example` maps Python/NumPy values to typed lists.

    Covers ints, strings and floats given both as plain Python values and as
    NumPy arrays, an empty array with an explicit dtype, and raw byte
    strings (scalar and list). Then checks the two error paths: an empty
    list with no dtype, and an array with an unsupported dtype (complex64).
    """
    raw_zeros = np.zeros(2, dtype=np.uint8).tobytes()
    example = file_format_adapter._dict_to_tf_example({
        "a": 1,
        "a2": np.array(1),
        "b": ["foo", "bar"],
        "b2": np.array(["foo", "bar"]),
        "c": [2.0],
        "c2": np.array([2.0]),
        # An empty value is accepted as long as its dtype is known.
        "d": np.array([], dtype=np.int32),
        # Byte strings, both as a single value and as a list.
        "e": raw_zeros,
        "e2": [raw_zeros] * 2,
    })

    features = example.features.feature
    # (feature key, name of the typed list field, expected values)
    expectations = [
        ("a", "int64_list", [1]),
        ("a2", "int64_list", [1]),
        ("b", "bytes_list", [b"foo", b"bar"]),
        ("b2", "bytes_list", [b"foo", b"bar"]),
        ("c", "float_list", [2.0]),
        ("c2", "float_list", [2.0]),
        ("d", "int64_list", []),
        ("e", "bytes_list", [b"\x00\x00"]),
        ("e2", "bytes_list", [b"\x00\x00", b"\x00\x00"]),
    ]
    for key, list_field, expected in expectations:
        actual = getattr(features[key], list_field).value
        self.assertEqual(expected, list(actual))

    # An empty list carries no dtype, so it must be rejected.
    with self.assertRaisesWithPredicateMatch(ValueError, "Received an empty"):
        file_format_adapter._dict_to_tf_example({"empty": []})

    # A complex64 array is expected to be rejected as an unsupported dtype.
    with self.assertRaisesWithPredicateMatch(ValueError, "not support type"):
        file_format_adapter._dict_to_tf_example(
            {"wrong_type": np.zeros(shape=(5,), dtype=np.complex64)})
 def test_convert_to_example_generator(self):
     """Every serialized item from `_generate_tf_examples` parses back to
     the example built directly from `self.example_dict`; 3 items expected.
     """
     example_stream = file_format_adapter._generate_tf_examples(
         self.generator())
     reference = file_format_adapter._dict_to_tf_example(self.example_dict)
     serialized_examples = list(example_stream)
     self.assertEqual(3, len(serialized_examples))
     for raw in serialized_examples:
         # FromString builds a fresh message and parses `raw` into it.
         parsed = tf.train.Example.FromString(raw)
         self.assertEqual(reference, parsed)
 def test_dict_to_example(self):
     """Check an int, a string list and a float list map to typed lists."""
     # NOTE(review): a method with this same name is also defined earlier in
     # this source; if both live in one class, this definition replaces the
     # earlier (broader) one. Confirm which version is intended.
     example = file_format_adapter._dict_to_tf_example(
         {"a": 1, "b": ["foo", "bar"], "c": [2.0]})
     feature_map = example.features.feature
     # (feature key, name of the typed list field, expected values)
     for name, kind, want in (
         ("a", "int64_list", [1]),
         ("b", "bytes_list", [b"foo", b"bar"]),
         ("c", "float_list", [2.0]),
     ):
         self.assertEqual(want, list(getattr(feature_map[name], kind).value))