Пример #1
0
 def __iter__(self):
     """Return an iterator over the records of this dataset.

     The records come from ``io_utils.multi_tfrecord_loader``, called with
     the instance's data pattern, index pattern, splits and description.
     When ``self.shuffle_queue_size`` is truthy the stream is additionally
     wrapped by ``iterator_utils.shuffle_iterator`` using that queue size.
     """
     records = io_utils.multi_tfrecord_loader(
         self.data_pattern,
         self.index_pattern,
         self.splits,
         self.description,
     )
     # A falsy queue size (0 / None) means "no shuffling requested".
     if not self.shuffle_queue_size:
         return records
     return iterator_utils.shuffle_iterator(
         records, self.shuffle_queue_size)
Пример #2
0
 def __iter__(self):
     """Return an iterator over the records of this dataset.

     If ``torch.utils.data.get_worker_info()`` reports worker information,
     NumPy's global RNG is seeded from that worker's seed first. The
     records are then produced by ``reader.multi_tfrecord_loader`` and,
     when ``self.shuffle_queue_size`` is truthy, passed through
     ``iterator_utils.shuffle_iterator``.
     """
     info = torch.utils.data.get_worker_info()
     if info is not None:
         # Reduce the worker seed modulo the uint32 maximum before handing
         # it to NumPy.
         worker_seed = info.seed
         seed_bound = np.iinfo(np.uint32).max
         np.random.seed(worker_seed % seed_bound)

     records = reader.multi_tfrecord_loader(
         self.data_pattern,
         self.index_pattern,
         self.splits,
         self.description,
     )
     if not self.shuffle_queue_size:
         return records
     return iterator_utils.shuffle_iterator(
         records, self.shuffle_queue_size)
Пример #3
0
 def __iter__(self):
     """Return an iterator over the records of a single TFRecord file.

     When worker information is available, the loader receives a
     ``(worker id, number of workers)`` shard tuple and NumPy's global RNG
     is seeded from the worker's seed; otherwise the shard is ``None``.
     When ``self.shuffle_queue_size`` is truthy the stream is wrapped by
     ``iterator_utils.shuffle_iterator``.
     """
     info = torch.utils.data.get_worker_info()
     if info is None:
         shard = None
     else:
         # NOTE(review): presumably lets each worker read its own portion
         # of the file -- confirm against reader.tfrecord_loader.
         shard = (info.id, info.num_workers)
         np.random.seed(info.seed % np.iinfo(np.uint32).max)

     records = reader.tfrecord_loader(
         self.data_path,
         self.index_path,
         self.description,
         shard,
     )
     if not self.shuffle_queue_size:
         return records
     return iterator_utils.shuffle_iterator(
         records, self.shuffle_queue_size)
Пример #4
0
 def __iter__(self):
     """Return an iterator over the (optionally transformed) records.

     Steps, in order:

     1. If worker information is available, seed NumPy's global RNG from
        the worker's seed.
     2. Create the record stream with ``reader.multi_tfrecord_loader``.
     3. If ``self.shuffle_queue_size`` is truthy, wrap the stream with
        ``iterator_utils.shuffle_iterator``.
     4. If ``self.transform`` is truthy, lazily apply it to every record
        with ``map``.
     """
     info = torch.utils.data.get_worker_info()
     if info is not None:
         worker_seed = info.seed
         np.random.seed(worker_seed % np.iinfo(np.uint32).max)

     stream = reader.multi_tfrecord_loader(
         data_pattern=self.data_pattern,
         index_pattern=self.index_pattern,
         splits=self.splits,
         description=self.description,
         sequence_description=self.sequence_description,
         compression_type=self.compression_type,
         infinite=self.infinite,
     )
     if self.shuffle_queue_size:
         stream = iterator_utils.shuffle_iterator(
             stream, self.shuffle_queue_size)

     # The transform is optional; without it the raw records are returned.
     return map(self.transform, stream) if self.transform else stream
Пример #5
0
 def __iter__(self):
     """Return an iterator over the (optionally transformed) records.

     When worker information is available, the loader receives a
     ``(worker id, number of workers)`` shard tuple and NumPy's global RNG
     is seeded from the worker's seed; otherwise the shard is ``None``.
     The stream from ``reader.tfrecord_loader`` is wrapped by
     ``iterator_utils.shuffle_iterator`` when ``self.shuffle_queue_size``
     is truthy, and mapped through ``self.transform`` when that is truthy.
     """
     info = torch.utils.data.get_worker_info()
     if info is None:
         shard = None
     else:
         shard = (info.id, info.num_workers)
         np.random.seed(info.seed % np.iinfo(np.uint32).max)

     stream = reader.tfrecord_loader(
         data_path=self.data_path,
         index_path=self.index_path,
         description=self.description,
         shard=shard,
         sequence_description=self.sequence_description,
         compression_type=self.compression_type,
     )
     if self.shuffle_queue_size:
         stream = iterator_utils.shuffle_iterator(
             stream, self.shuffle_queue_size)

     # The transform is optional; without it the raw records are returned.
     return map(self.transform, stream) if self.transform else stream
Пример #6
0
    def __iter__(self):
        """Return a generator of ``(image, label)`` pairs.

        Records are read with ``tfrecord.tfrecord_loader``. When worker
        information is available, the loader receives a
        ``(worker id, number of workers)`` shard tuple and NumPy's global
        RNG is seeded from the worker's seed; otherwise the shard is
        ``None``. The stream is wrapped by
        ``iterator_utils.shuffle_iterator`` when
        ``self.shuffle_queue_size`` is truthy.

        Each record's ``"image"`` bytes are opened with PIL, converted to
        RGB and passed through ``self.transform``; the label is the first
        element of the record's ``"label"`` field.
        """
        info = torch.utils.data.get_worker_info()
        if info is None:
            shard = None
        else:
            shard = (info.id, info.num_workers)
            np.random.seed(info.seed % np.iinfo(np.uint32).max)

        records = tfrecord.tfrecord_loader(
            self.data_path,
            self.index_path,
            self.description,
            shard,
        )
        if self.shuffle_queue_size:
            records = iterator_utils.shuffle_iterator(
                records, self.shuffle_queue_size)

        def decoded(stream):
            # Lazily turn each raw record into a (transformed image, label)
            # pair; nothing is read until the generator is consumed.
            for record in stream:
                raw_bytes = BytesIO(record["image"])
                rgb_image = PIL.Image.open(raw_bytes).convert("RGB")
                image = self.transform(rgb_image)
                label = record["label"][0]
                yield image, label

        return decoded(records)


# loader = tfrecord.tfrecord_loader("/data/ImageNet/train.tfrecord", "/data/ImageNet/train.idx", {
#     "image": "byte",
#     "label": "int",

# })

# def recordloader():
#     with open(path, 'rb') as f:
#         img = Image.open(f)
#         return img.convert('RGB')