def __call__(self, record_path, offset, label):
    """Load one image example from a record file and return it with its label.

    A ``db.RecordReader`` is opened lazily for each distinct ``record_path``
    and cached in ``self.file_readers``, so later calls for the same file
    reuse the open reader.

    Args:
        record_path: Path of the record file that holds the example.
        offset: Value passed to ``RecordReader.read_record`` to select the record.
        label: Returned unchanged alongside the image.

    Returns:
        ``(image, label)``, where ``image`` is the decoded image converted to
        RGB channel order and passed through ``self.transform`` when that is
        not ``None``.

    Raises:
        ValueError: If the example's ``'image'`` bytes cannot be decoded by
            OpenCV.
    """
    reader = self.file_readers.get(record_path)
    if reader is None:
        reader = db.RecordReader(record_path)
        self.file_readers[record_path] = reader

    example = example_pb2.Example()
    example.ParseFromString(reader.read_record(offset))
    image_raw = example.features.feature['image'].bytes_list.value[0]

    image = cv2.imdecode(np.frombuffer(image_raw, np.uint8), cv2.IMREAD_COLOR)
    # cv2.imdecode signals failure by returning None rather than raising;
    # fail here with context instead of letting cvtColor raise an opaque error.
    if image is None:
        raise ValueError(
            'Failed to decode image from record at offset {} in {!r}'.format(
                offset, record_path))
    # IMREAD_COLOR yields BGR channel order; convert to RGB for consumers.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    if self.transform is not None:
        image = self.transform(image)
    return image, label
def test_reading_record(self):
    """Read every record of the small compressed fixture and compare to ground truth.

    The fixture file name mentions gzip, but the reader is opened with
    ``db.Compression.ZLIB``. Both the metadata entry count and the number of
    iterated records must be 50, and the records must equal the pickled
    reference list.
    """
    reader = db.RecordReader('test_utils/test-small-gzip-r00.tfrecords', db.Compression.ZLIB)
    self.assertIsNotNone(reader)

    # Only the entry count from the metadata triple is checked here.
    _, _, num_entries = reader.get_metadata()
    self.assertEqual(num_entries, 50)

    loaded = list(reader)
    self.assertEqual(len(loaded), 50)

    # Compare against ground-truth records captured earlier, to confirm
    # the container was parsed correctly.
    with open('test_utils/test-small-records-gzip-r00.pth', 'rb') as handle:
        expected = pickle.load(handle)
    self.assertEqual(expected, loaded)
def test_reading_record_does_not_exist(self):
    """Opening a missing file must raise RuntimeError with an exact message."""
    missing_path = 'does_not_exist.tfrecords'
    expected_message = "Can't create RecordReader. Can't find file: does_not_exist.tfrecords"

    with self.assertRaises(RuntimeError) as ctx:
        db.RecordReader(missing_path)

    # The first exception argument carries the human-readable message.
    self.assertEqual(expected_message, ctx.exception.args[0])
def simple_reading_of_records(paths=None):
    """Read every record from each file sequentially and return them in order.

    Args:
        paths: Iterable of record-file paths. When omitted, the module-level
            ``filenames`` is used, matching the original zero-argument behavior.

    Returns:
        A list of all records, concatenated in file order. Previously the
        list was built and then discarded; returning it lets callers verify
        the result.
    """
    if paths is None:
        paths = filenames
    records = []
    for filename in paths:
        # extend() consumes the reader directly, with no temporary list copy.
        records.extend(db.RecordReader(filename))
    return records