def __init__(self, filenames, features, batch_size, buffer_size=1000, seed=None, epoch=0, compression=None):
    """Build a parser and a randomized parsed-record yielder for batching.

    Args:
        filenames: Record files forwarded to ``db.ParsedRecordYielderRandomized``.
        features: Feature specification forwarded to ``db.RecordParser``.
        batch_size: Kept on the instance as ``self.batch_size``.
        buffer_size: Forwarded to the yielder (default 1000).
        seed: Seed for the randomized yielder. When omitted, it is derived
            from the current wall-clock time in milliseconds.
        epoch: Forwarded to the yielder.
        compression: Compression setting for the record files; defaults to
            ``db.Compression.NONE``.
    """
    # Without an explicit seed, each run gets a time-derived (different) seed.
    seed = np.uint64(time.time() * 1000) if seed is None else seed
    compression = db.Compression.NONE if compression is None else compression

    self.parser = db.RecordParser(features, True)
    self.record_yielder = db.ParsedRecordYielderRandomized(
        self.parser, filenames, buffer_size, seed, epoch, compression
    )
    self.batch_size = batch_size
def test_parsing_records_in_batch(self):
    """Batch-parse all records as uint8 tensors and compare with ground truth."""
    spec = {'data': db.FixedLenFeature([3, 32, 32], db.uint8)}
    batch_parser = db.RecordParser(spec, False)
    self.assertIsNotNone(batch_parser)
    # Take the first output of parse_example; 'data' is the only feature requested.
    images = batch_parser.parse_example(self.records)[0]
    self.assertTrue(np.all(images == self.images_gt))
def test_parsing_single_record_inplace_with_uint8(self):
    """Parse each record into one reused uint8 buffer and check its contents."""
    parser = db.RecordParser(
        {'data': db.FixedLenFeature([3, 32, 32], db.uint8)}, False
    )
    self.assertIsNotNone(parser)
    # The same buffer is overwritten by every parse call below.
    buffer = np.zeros([3, 32, 32], dtype=np.uint8)
    for rec, expected in zip(self.records, self.images_gt):
        parser.parse_single_example_inplace(rec, [buffer], 0)
        self.assertTrue(np.all(buffer == expected))
def test_parsing_single_record_raise2(self):
    """Requesting 'data' as int64 must raise a dtype-mismatch RuntimeError.

    The stored 'data' feature is a string, so parsing any record with an
    int64 spec is expected to raise RuntimeError with the exact message
    asserted below.
    """
    parser = db.RecordParser({'data': db.FixedLenFeature([], db.int64)})
    self.assertIsNotNone(parser)
    with self.assertRaises(RuntimeError) as context:
        for record, _ in zip(self.records, self.images_gt):
            # The result is deliberately not unpacked. Only one feature is
            # requested, so the result presumably holds a single value. A
            # two-name unpack would turn a missing exception into an unrelated
            # ValueError instead of a clean "RuntimeError not raised" failure.
            parser.parse_single_example(record)
    self.assertEqual(
        'Feature: data. Data types don\'t match. Expected type: int64, Feature is: string.',
        context.exception.args[0],
    )
def test_parsing_single_record_raise(self):
    """Requiring a feature absent from the records must raise RuntimeError.

    'does_not_exist' is declared as a required fixed-length string feature.
    Parsing any record is expected to fail with the exact message asserted
    below.
    """
    parser = db.RecordParser(
        {'does_not_exist': db.FixedLenFeature([], db.string)}
    )
    self.assertIsNotNone(parser)
    with self.assertRaises(RuntimeError) as context:
        for record, _ in zip(self.records, self.images_gt):
            # The result is deliberately not unpacked. Only one feature is
            # requested, so a two-name unpack would mask a missing exception
            # behind an unrelated ValueError.
            parser.parse_single_example(record)
    self.assertEqual(
        'Feature does_not_exist is required but could not be found.',
        context.exception.args[0],
    )
def test_parsing_single_record_with_uint8_alternative_to_string_raise(self):
    """A uint8 output shape that disagrees with the stored byte count must raise."""
    spec = {
        'shape': db.FixedLenFeature([3], db.int64),
        # Deliberately wrong: the error message below reports 3072 stored
        # values, which does not fit a [3, 32, 64] output.
        'data': db.FixedLenFeature([3, 32, 64], db.uint8),
    }
    parser = db.RecordParser(spec)
    self.assertIsNotNone(parser)
    expected_msg = (
        'Key: data. Number of uint8 values != expected. '
        'Values size: 3072 but output shape: [3, 32, 64].'
    )
    with self.assertRaises(RuntimeError) as ctx:
        for rec, _ in zip(self.records, self.images_gt):
            _shape, _data = parser.parse_single_example(rec)
    self.assertEqual(expected_msg, ctx.exception.args[0])
def test_parsing_single_record_with_uint8_alternative_to_string(self):
    """Read 'data' directly as a [3, 32, 32] uint8 tensor instead of raw bytes."""
    parser = db.RecordParser({
        'shape': db.FixedLenFeature([3], db.int64),
        'data': db.FixedLenFeature([3, 32, 32], db.uint8),
    })
    self.assertIsNotNone(parser)
    for rec, expected in zip(self.records, self.images_gt):
        parsed_shape, parsed_image = parser.parse_single_example(rec)
        self.assertTrue(np.all(parsed_shape == [3, 32, 32]))
        self.assertTrue(np.all(parsed_image == expected))
def test_yielder_randomized_parallel():
    """Drain a randomized yielder in chunks of 32 until it is exhausted.

    This is a smoke test: it only checks that iteration finishes without an
    unexpected error and asserts nothing about the collected records. It
    relies on a ``filenames`` name defined outside this function.
    """
    parser = db.RecordParser(
        {'data': db.FixedLenFeature([3, 256, 256], db.uint8)}, True
    )
    yielder = db.ParsedRecordYielderRandomized(parser, filenames, 64, 1, 0)
    collected = []
    while True:
        try:
            chunk = yielder.next_n(32)
        except StopIteration:
            # The yielder signals exhaustion by raising StopIteration.
            break
        collected += chunk
def test_parsing_single_record_inplace(self):
    """Parse shape and raw bytes into preallocated buffers, then rebuild images."""
    parser = db.RecordParser({
        'shape': db.FixedLenFeature([3], db.int64),
        'data': db.FixedLenFeature([], db.string)})
    self.assertIsNotNone(parser)
    shape_buf = np.zeros(3, dtype=np.int64)
    # A one-slot object array, so the parser can store a bytes value in it.
    bytes_buf = np.asarray([bytes()], dtype=object)
    for rec, expected in zip(self.records, self.images_gt):
        # The trailing 0 is presumably the slot/row index into the buffers; confirm.
        parser.parse_single_example_inplace(rec, [shape_buf, bytes_buf], 0)
        decoded = np.frombuffer(bytes_buf[0], dtype=np.uint8).reshape(shape_buf)
        self.assertTrue(np.all(decoded == expected))
def test_parsing_single_record(self):
    """Parse shape plus raw bytes, then decode the bytes into a 3x32x32 image."""
    parser = db.RecordParser({
        'shape': db.FixedLenFeature([3], db.int64),
        'data': db.FixedLenFeature([], db.string),
    })
    self.assertIsNotNone(parser)
    for rec, expected in zip(self.records, self.images_gt):
        shape, raw = parser.parse_single_example(rec)
        self.assertTrue(np.all(shape == [3, 32, 32]))
        image = np.frombuffer(raw[0], dtype=np.uint8).reshape(3, 32, 32)
        self.assertTrue(np.all(image == expected))