예제 #1
0
 def __init__(self, filenames, features, batch_size, buffer_size=1000, seed=None, epoch=0, compression=None):
     """Set up a record parser and a randomized parsed-record yielder.

     Args:
         filenames: Forwarded unchanged to ``db.ParsedRecordYielderRandomized``.
         features: Feature specification handed to ``db.RecordParser``.
         batch_size: Stored on the instance as ``self.batch_size``.
         buffer_size: Forwarded to the yielder (default 1000).
         seed: Forwarded to the yielder. When ``None``, the current
             wall-clock time in milliseconds, converted to ``np.uint64``,
             is used instead.
         epoch: Forwarded to the yielder (default 0).
         compression: Forwarded to the yielder. When ``None``,
             ``db.Compression.NONE`` is used instead.
     """
     effective_seed = np.uint64(time.time() * 1000) if seed is None else seed
     codec = db.Compression.NONE if compression is None else compression

     # NOTE(review): the meaning of the second ``RecordParser`` argument
     # (``True``) is not visible from here -- confirm against ``db``.
     self.parser = db.RecordParser(features, True)
     self.record_yielder = db.ParsedRecordYielderRandomized(
         self.parser,
         filenames,
         buffer_size,
         effective_seed,
         epoch,
         codec,
     )
     self.batch_size = batch_size
예제 #2
0
    def test_parsing_records_in_batch(self):
        """Parse every record in one ``parse_example`` call and compare.

        Only the ``'data'`` feature is requested, as a ``[3, 32, 32]`` uint8
        tensor. The first element of the parse result is expected to equal
        ``self.images_gt`` element-wise.
        """
        image_only_spec = {'data': db.FixedLenFeature([3, 32, 32], db.uint8)}

        # NOTE(review): the second constructor argument (``False``) is not
        # explained by this snippet -- confirm its meaning against ``db``.
        batch_parser = db.RecordParser(image_only_spec, False)
        self.assertIsNotNone(batch_parser)

        parsed_outputs = batch_parser.parse_example(self.records)
        images = parsed_outputs[0]
        self.assertTrue(np.all(images == self.images_gt))
예제 #3
0
    def test_parsing_single_record_inplace_with_uint8(self):
        """Parse records one by one into a preallocated uint8 buffer.

        A single ``[3, 32, 32]`` uint8 array is reused as the output target
        for every record; after each in-place parse it is expected to equal
        the matching ground-truth image.
        """
        image_only_spec = {'data': db.FixedLenFeature([3, 32, 32], db.uint8)}

        inplace_parser = db.RecordParser(image_only_spec, False)
        self.assertIsNotNone(inplace_parser)

        image_buffer = np.zeros((3, 32, 32), dtype=np.uint8)

        for raw_record, expected_image in zip(self.records, self.images_gt):
            # NOTE(review): the trailing ``0`` argument is not explained by
            # this snippet -- confirm its meaning against the ``db`` API.
            inplace_parser.parse_single_example_inplace(raw_record, [image_buffer], 0)
            self.assertTrue(np.all(image_buffer == expected_image))
예제 #4
0
    def test_parsing_single_record_raise2(self):
        """Requesting ``'data'`` as int64 must raise a type-mismatch error.

        The test expects a ``RuntimeError`` whose first argument is the exact
        message asserted at the end of this method.
        """
        mismatched_spec = {'data': db.FixedLenFeature([], db.int64)}

        single_parser = db.RecordParser(mismatched_spec)
        self.assertIsNotNone(single_parser)

        with self.assertRaises(RuntimeError) as raised:
            for raw_record, _expected_image in zip(self.records, self.images_gt):
                _shape, _payload = single_parser.parse_single_example(raw_record)

        expected_message = 'Feature: data. Data types don\'t match. Expected type: int64,  Feature is: string.'
        self.assertEqual(expected_message, raised.exception.args[0])
예제 #5
0
    def test_parsing_single_record_raise(self):
        """Requesting a feature absent from the records must raise.

        The test expects a ``RuntimeError`` whose first argument is the exact
        message asserted at the end of this method.
        """
        missing_feature_spec = {'does_not_exist': db.FixedLenFeature([], db.string)}

        single_parser = db.RecordParser(missing_feature_spec)
        self.assertIsNotNone(single_parser)

        with self.assertRaises(RuntimeError) as raised:
            for raw_record, _expected_image in zip(self.records, self.images_gt):
                _shape, _payload = single_parser.parse_single_example(raw_record)

        expected_message = 'Feature does_not_exist is required but could not be found.'
        self.assertEqual(expected_message, raised.exception.args[0])
예제 #6
0
    def test_parsing_single_record_with_uint8_alternative_to_string_raise(self):
        """A uint8 ``'data'`` shape that does not fit the payload must raise.

        ``'data'`` is declared as ``[3, 32, 64]`` uint8, while the expected
        error text reports a values size of 3072. The test expects a
        ``RuntimeError`` carrying exactly that message.
        """
        oversized_spec = {
            'shape': db.FixedLenFeature([3], db.int64),
            'data': db.FixedLenFeature([3, 32, 64], db.uint8),
        }

        single_parser = db.RecordParser(oversized_spec)
        self.assertIsNotNone(single_parser)

        with self.assertRaises(RuntimeError) as raised:
            for raw_record, _expected_image in zip(self.records, self.images_gt):
                _shape, _payload = single_parser.parse_single_example(raw_record)

        expected_message = 'Key: data. Number of uint8 values != expected. Values size: 3072 but output shape: [3, 32, 64].'
        self.assertEqual(expected_message, raised.exception.args[0])
예제 #7
0
    def test_parsing_single_record_with_uint8_alternative_to_string(self):
        """Parse each record with ``'data'`` declared as a uint8 tensor.

        For every record the parsed shape is expected to be ``[3, 32, 32]``
        and the parsed data is expected to equal the ground-truth image.
        """
        uint8_spec = {
            'shape': db.FixedLenFeature([3], db.int64),
            'data': db.FixedLenFeature([3, 32, 32], db.uint8),
        }

        single_parser = db.RecordParser(uint8_spec)
        self.assertIsNotNone(single_parser)

        for raw_record, expected_image in zip(self.records, self.images_gt):
            parsed_shape, parsed_image = single_parser.parse_single_example(raw_record)

            shape_matches = np.all(parsed_shape == [3, 32, 32])
            self.assertTrue(shape_matches)
            image_matches = np.all(parsed_image == expected_image)
            self.assertTrue(image_matches)
예제 #8
0
 def test_yielder_randomized_parallel():
     """Drain a randomized parsed-record yielder in chunks of 32.

     Relies on a module-level ``filenames`` that is defined outside this
     snippet. The loop stops when ``next_n`` raises ``StopIteration``;
     nothing is asserted on the collected records.
     """
     # Only the image payload is requested; a 3-element int64 'shape'
     # feature is not part of this spec.
     image_spec = {'data': db.FixedLenFeature([3, 256, 256], db.uint8)}
     image_parser = db.RecordParser(image_spec, True)

     # NOTE(review): the positional values 64, 1, 0 are presumably buffer
     # size, seed and epoch -- confirm against the ``db`` API.
     yielder = db.ParsedRecordYielderRandomized(image_parser, filenames, 64, 1, 0)

     collected = []
     while True:
         try:
             collected.extend(yielder.next_n(32))
         except StopIteration:
             break
예제 #9
0
    def test_parsing_single_record_inplace(self):
        """Parse each record in place into preallocated shape/data arrays.

        ``'data'`` is requested as a string feature, so its output slot is a
        one-element object array. After each parse the bytes in that slot
        are reinterpreted as uint8, reshaped with the parsed shape, and
        compared against the ground-truth image.
        """
        string_spec = {
            'shape': db.FixedLenFeature([3], db.int64),
            'data': db.FixedLenFeature([], db.string),
        }

        inplace_parser = db.RecordParser(string_spec)
        self.assertIsNotNone(inplace_parser)

        shape_buffer = np.zeros((3,), dtype=np.int64)
        payload_buffer = np.asarray([b''], dtype=object)

        for raw_record, expected_image in zip(self.records, self.images_gt):
            # NOTE(review): the trailing ``0`` argument is not explained by
            # this snippet -- confirm its meaning against the ``db`` API.
            inplace_parser.parse_single_example_inplace(
                raw_record, [shape_buffer, payload_buffer], 0)
            flat_pixels = np.frombuffer(payload_buffer[0], dtype=np.uint8)
            decoded_image = flat_pixels.reshape(shape_buffer)
            self.assertTrue(np.all(decoded_image == expected_image))
예제 #10
0
    def test_parsing_single_record(self):
        """Parse each record individually with ``'data'`` as a string feature.

        For every record the parsed shape is expected to be ``[3, 32, 32]``.
        The first element of the parsed data is then reinterpreted as uint8,
        reshaped to ``(3, 32, 32)``, and compared against the ground-truth
        image.
        """
        string_spec = {
            'shape': db.FixedLenFeature([3], db.int64),
            'data': db.FixedLenFeature([], db.string),
        }

        single_parser = db.RecordParser(string_spec)
        self.assertIsNotNone(single_parser)

        for raw_record, expected_image in zip(self.records, self.images_gt):
            parsed_shape, parsed_payload = single_parser.parse_single_example(raw_record)

            self.assertTrue(np.all(parsed_shape == [3, 32, 32]))

            flat_pixels = np.frombuffer(parsed_payload[0], dtype=np.uint8)
            decoded_image = flat_pixels.reshape(3, 32, 32)
            self.assertTrue(np.all(decoded_image == expected_image))