def _dump_examples_to_tfrecord(path, examples):
  """Writes list of example dicts to a TFRecord file of tf.Example protos.

  Args:
    path: destination file path handed to `tf.io.TFRecordWriter`.
    examples: iterable of dicts; each one is converted with
      `dataset_utils.dict_to_tfexample` and written as one serialized record,
      in iteration order.
  """
  logging.info("Writing examples to TFRecord: %s", path)
  # The context manager closes the writer even if a conversion raises.
  with tf.io.TFRecordWriter(path) as tfrecord_writer:
    for example_dict in examples:
      proto = dataset_utils.dict_to_tfexample(example_dict)
      tfrecord_writer.write(proto.SerializeToString())
def test_dict_to_tfexample(self):
  """dict_to_tfexample routes each value type to the expected feature list."""
  tfe = utils.dict_to_tfexample({
      "inputs": "this is an input",
      "targets": "this is a target",
      "weight": 5.0,
      "idx1": np.array([1, 2], np.int32),
      "idx2": np.array([3, 4], np.int64),
      "is_correct": False,
  })
  features = tfe.features.feature
  self.assertLen(features, 6)

  # Python strings are expected to come back as bytes in bytes_list.
  string_cases = (
      ("inputs", b"this is an input"),
      ("targets", b"this is a target"),
  )
  for key, expected_bytes in string_cases:
    self.assertEqual(features[key].bytes_list.value, [expected_bytes])

  # A Python float is expected in float_list.
  self.assertEqual(features["weight"].float_list.value, [5.0])

  # Integer arrays of either width, and bools, are expected in int64_list.
  int_cases = (
      ("idx1", [1, 2]),
      ("idx2", [3, 4]),
      ("is_correct", [0]),
  )
  for key, expected_ints in int_cases:
    np.testing.assert_array_equal(
        features[key].int64_list.value, np.array(expected_ints, np.int64))