Code example #1
File: pytorch_loaders.py Project: seanpmorgan/armory
    def __init__(self, dataset_name, dataset_ver, split, epochs):
        self.data_files = locate_data(dataset_name, dataset_ver, split)
        self.features = {
            "label": db.FixedLenFeature([], db.int64),
            "image": db.FixedLenFeature([], db.string),
        }

        self.epochs = epochs
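All of the excerpts on this page refer to the DareBlopy package through the alias db, together with NumPy and the standard time module. A minimal setup sketch of the assumed common imports (the aliases follow the usual convention and are not shown in any single excerpt):

    # Assumed common imports for the snippets below (not part of the original excerpts).
    import time             # used to seed shuffling in the iterator examples
    import numpy as np      # arrays, frombuffer/reshape in the parsing examples
    import dareblopy as db  # TFRecords reading and parsing without TensorFlow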
Code example #2
    def test_parsing_single_record_with_uint8_alternative_to_string_raise(self):
        features_alternative = {
            'shape': db.FixedLenFeature([3], db.int64),
            'data': db.FixedLenFeature([3, 32, 64], db.uint8)
        }

        parser = db.RecordParser(features_alternative)
        self.assertIsNotNone(parser)

        with self.assertRaises(RuntimeError) as context:
            for record, image_gt in zip(self.records, self.images_gt):
                shape, data = parser.parse_single_example(record)

        self.assertEqual('Key: data. Number of uint8 values != expected. Values size: 3072 but output shape: [3, 32, 64].', context.exception.args[0])
Code example #3
    def test_parsing_single_record_with_uint8_alternative_to_string(self):
        features_alternative = {
            'shape': db.FixedLenFeature([3], db.int64),
            'data': db.FixedLenFeature([3, 32, 32], db.uint8)
        }

        parser = db.RecordParser(features_alternative)
        self.assertIsNotNone(parser)

        for record, image_gt in zip(self.records, self.images_gt):
            shape, data = parser.parse_single_example(record)

            self.assertTrue(np.all(shape == [3, 32, 32]))
            self.assertTrue(np.all(data == image_gt))
Code example #4
    def test_parsing_single_record_inplace(self):
        features = {
            'shape': db.FixedLenFeature([3], db.int64),
            'data': db.FixedLenFeature([], db.string)}

        parser = db.RecordParser(features)
        self.assertIsNotNone(parser)

        shape = np.zeros(3, dtype=np.int64)
        data = np.asarray([bytes()], dtype=object)

        for record, image_gt in zip(self.records, self.images_gt):
            parser.parse_single_example_inplace(record, [shape, data], 0)
            image = np.frombuffer(data[0], dtype=np.uint8).reshape(shape)
            self.assertTrue(np.all(image == image_gt))
Code example #5
    def test_parsing_single_record(self):
        features = {
            'shape': db.FixedLenFeature([3], db.int64),
            'data': db.FixedLenFeature([], db.string)
        }

        parser = db.RecordParser(features)
        self.assertIsNotNone(parser)

        for record, image_gt in zip(self.records, self.images_gt):
            shape, data = parser.parse_single_example(record)

            self.assertTrue(np.all(shape == [3, 32, 32]))
            self.assertTrue(np.all(
                np.frombuffer(data[0], dtype=np.uint8).reshape(3, 32, 32) == image_gt
            ))
Code example #6
    def reset(self, lod, batch_size):
        assert lod in self.filenames.keys()
        self.current_filenames = self.filenames[lod]
        self.batch_size = batch_size
        img_size = 2**lod
        self.features = {
            'shape': db.FixedLenFeature([3], db.int64),
            'data': db.FixedLenFeature([3, img_size, img_size], db.uint8)
        }
        # Number of records that fit into the byte budget (3 * H * W bytes per image).
        buffer_size = self.buffer_size_b // (3 * img_size * img_size)
        self.iterator = db.ParsedTFRecordsDatasetIterator(
            self.current_filenames,
            self.features,
            self.batch_size,
            buffer_size,
            seed=np.uint64(time.time() * 1000))
Code example #7
def reading_tf_records_from_dareblopy():
    # Decode the 'data' feature directly into 3x256x256 uint8 tensors.
    features = {'data': db.FixedLenFeature([3, 256, 256], db.uint8)}
    iterator = db.data_loader(
        db.ParsedTFRecordsDatasetIterator(filenames, features, batch_size, 64),
        worker_count=6)
    records = []
    for batch in iterator:
        records += batch
Code example #8
def reading_tf_records_from_dareblopy_withoutdecoding():
    # Keep the 'data' feature as raw string bytes; skip decoding into uint8 tensors.
    features = {'data': db.FixedLenFeature([], db.string)}
    iterator = db.data_loader(
        db.ParsedTFRecordsDatasetIterator(filenames, features, batch_size, 128),
        worker_count=6)
    records = []
    for batch in iterator:
        records += batch
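In this variant the 'data' feature is left as a string, so each parsed record comes back as raw bytes rather than a decoded uint8 tensor. A sketch of how a consumer might decode those bytes afterwards, following the np.frombuffer pattern from the parser tests above and assuming each string holds a raw 3x256x256 uint8 image (the helper name is illustrative):

    def decode_raw_record(raw_bytes):
        # Reinterpret the raw feature bytes as a CHW uint8 image, as in code example #5.
        return np.frombuffer(raw_bytes, dtype=np.uint8).reshape(3, 256, 256)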
Code example #9
    def test_dataset_iterator(self):
        features = {
            'data': db.FixedLenFeature([3, 32, 32], db.uint8)
        }
        iterator = db.ParsedTFRecordsDatasetIterator(['test_utils/test-small-r00.tfrecords'],
                                                     features, 32, buffer_size=1)

        images = np.concatenate([x[0] for x in iterator], axis=0)
        self.assertTrue(np.all(images == self.images_gt))
Code example #10
    def test_parsing_records_in_batch(self):
        features_alternative = {
            'data': db.FixedLenFeature([3, 32, 32], db.uint8)
        }

        parser = db.RecordParser(features_alternative, False)
        self.assertIsNotNone(parser)

        data = parser.parse_example(self.records)[0]
        self.assertTrue(np.all(data == self.images_gt))
Code example #11
def test_ParsedTFRecordsDatasetIterator():
    features = {
        # 'shape': db.FixedLenFeature([3], db.int64),
        'data': db.FixedLenFeature([3, 256, 256], db.uint8)
    }
    iterator = db.ParsedTFRecordsDatasetIterator(filenames, features, 32, 64)
    records = []
    for batch in iterator:
        records += batch
Code example #12
File: dataloader.py Project: jhejna/ul_gen
    def reset(self, lod, batch_size):
        assert lod in self.filenames.keys()
        self.current_filenames = self.filenames[lod]
        self.batch_size = batch_size

        img_size = 2**lod

        if self.needs_labels:
            self.features = {
                # 'shape': db.FixedLenFeature([3], db.int64),
                'data': db.FixedLenFeature([self.channels, img_size, img_size], db.uint8),
                'label': db.FixedLenFeature([], db.int64)
            }
        else:
            self.features = {
                # 'shape': db.FixedLenFeature([3], db.int64),
                'data': db.FixedLenFeature([self.channels, img_size, img_size], db.uint8)
            }

        buffer_size = self.buffer_size_b // (self.channels * img_size * img_size)

        if self.seed is None:
            seed = np.uint64(time.time() * 1000)
        else:
            seed = self.seed
            self.logger.info('!' * 80)
            self.logger.info(
                '! Seed is used to shuffle data in TFRecordsDataset!')
            self.logger.info('!' * 80)

        self.iterator = db.ParsedTFRecordsDatasetIterator(
            self.current_filenames,
            self.features,
            self.batch_size,
            buffer_size,
            seed=seed)
Code example #13
    def test_parsing_single_record_inplace_with_uint8(self):
        features_alternative = {
            'data': db.FixedLenFeature([3, 32, 32], db.uint8)
        }

        parser = db.RecordParser(features_alternative, False)
        self.assertIsNotNone(parser)

        data = np.zeros([3, 32, 32], dtype=np.uint8)

        for record, image_gt in zip(self.records, self.images_gt):
            parser.parse_single_example_inplace(record, [data], 0)
            self.assertTrue(np.all(data == image_gt))
Code example #14
    def test_parsing_single_record_raise2(self):
        features = {
            'data': db.FixedLenFeature([], db.int64)
        }

        parser = db.RecordParser(features)
        self.assertIsNotNone(parser)

        with self.assertRaises(RuntimeError) as context:
            for record, image_gt in zip(self.records, self.images_gt):
                shape, data = parser.parse_single_example(record)

        self.assertEqual('Feature: data. Data types don\'t match. Expected type: int64,  Feature is: string.', context.exception.args[0])
Code example #15
    def test_parsing_single_record_raise(self):
        features = {
            'does_not_exist': db.FixedLenFeature([], db.string)
        }

        parser = db.RecordParser(features)
        self.assertIsNotNone(parser)

        with self.assertRaises(RuntimeError) as context:
            for record, image_gt in zip(self.records, self.images_gt):
                shape, data = parser.parse_single_example(record)

        self.assertEqual('Feature does_not_exist is required but could not be found.', context.exception.args[0])
Code example #16
def test_yielder_randomized_parallel():
    features = {
        # 'shape': db.FixedLenFeature([3], db.int64),
        'data': db.FixedLenFeature([3, 256, 256], db.uint8)
    }
    parser = db.RecordParser(features, True)
    record_yielder = db.ParsedRecordYielderRandomized(parser, filenames, 64, 1, 0)
    records = []
    while True:
        try:
            records += record_yielder.next_n(32)
        except StopIteration:
            break