def load_file(self):
    # Read every record from the current file into memory.
    reader = tfrecord_loader(self.files[self.file_index], None,
                             list(TFRECORD_KEYS))
    data = []
    for datum in reader:
        data.append(datum)
    return data
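Note: `tfrecord_loader` in these snippets is `tfrecord.reader.tfrecord_loader` from the `tfrecord` package, and `TFRECORD_KEYS` is never defined anywhere above; it is presumably a module-level feature schema. A minimal, self-contained sketch, assuming a hypothetical two-feature schema:

from tfrecord.reader import tfrecord_loader

# Hypothetical schema: the real TFRECORD_KEYS is not shown in any snippet.
# The description may be a dict mapping feature names to "byte", "float"
# or "int", or a plain list of names (feature types are then inferred).
TFRECORD_KEYS = {"image": "byte", "label": "int"}

# index_path=None reads the file sequentially, without an index.
for record in tfrecord_loader("train-000.tfrecord", None, list(TFRECORD_KEYS)):
    image_bytes, label = record["image"], record["label"]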
Example #2
def samples_per_file(self, filename):
    # Count the number of records in a single TFRecord file.
    reader = tfrecord_loader(filename,
                             None,
                             list(TFRECORD_KEYS))
    count = 0
    for _ in reader:
        count += 1
    return count
Example #3
def load_file(self):
    # Same as Example #1, but with an index file (derived from the
    # .tfrecord path) so the reader can be sharded across workers.
    reader = tfrecord_loader(self.files[self.file_index],
                             self.files[self.file_index].replace(".tfrecord", ".index"),
                             list(TFRECORD_KEYS),
                             self.shard)
    data = []
    for datum in reader:
        data.append([datum[key] for key in TFRECORD_KEYS])
    return data
Example #4
def __iter__(self):
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        # Shard records across DataLoader workers and give each
        # worker a distinct NumPy seed.
        shard = worker_info.id, worker_info.num_workers
        np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
    else:
        shard = None
    it = reader.tfrecord_loader(self.data_path, self.index_path,
                                self.description, shard)
    if self.shuffle_queue_size:
        it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
    return it
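`iterator_utils.shuffle_iterator` (from the same `tfrecord` package) performs an approximate shuffle through a bounded buffer. A rough sketch of the technique, not the library's exact implementation:

import random
from typing import Iterator, TypeVar

T = TypeVar("T")

def shuffle_buffer(it: Iterator[T], queue_size: int) -> Iterator[T]:
    # Keep up to `queue_size` items buffered; emit a random buffered
    # item each time a new one arrives, then drain the remainder.
    buffer = []
    for item in it:
        buffer.append(item)
        if len(buffer) >= queue_size:
            idx = random.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    random.shuffle(buffer)
    yield from buffer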
Example #5
def __next__(self):
    try:
        datum = next(self.reader)
    except StopIteration:
        # The current reader is exhausted; open the next file,
        # or stop for good once all files have been consumed.
        if self.file_index >= len(self.files):
            raise StopIteration
        self.reader = tfrecord_loader(
            self.files[self.file_index],
            self.files[self.file_index].replace(".tfrecord", ".index"),
            list(TFRECORD_KEYS), self.shard)
        self.file_index += 1
        datum = next(self.reader)
    # Return the record as a list ordered by TFRECORD_KEYS.
    datum = [datum[key] for key in TFRECORD_KEYS]
    return datum
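Examples #1/#3 and #5 read like methods of the same multi-file iterator class. A hypothetical skeleton showing how the `__next__` above slots in (every name outside the snippets is an assumption):

class MultiFileIterator:
    def __init__(self, files, shard=None):
        self.files = files        # list of .tfrecord paths
        self.shard = shard        # (worker_id, num_workers) or None
        self.file_index = 0
        self.reader = iter(())    # empty: the first next() opens files[0]

    def __iter__(self):
        return self

    # __next__ as in Example #5: on StopIteration it opens
    # self.files[self.file_index], then increments file_index.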
Example #6
def __iter__(self):
    train_type = 'train' if self.train else 'test'

    # Distribute shards among interleaved workers.
    worker_info = th.utils.data.get_worker_info()
    if worker_info is None:
        i0 = 0
        step = 1
    else:
        i0 = worker_info.id
        step = worker_info.num_workers

    for shard in self.shards[i0::step]:
        # Download blob and parse tfrecord.
        # TODO(ycho): Consider downloading in the background.

        fp = None
        try:
            if self.opts.local:
                # Load local shard.
                fp = open(str(shard), 'rb')
            else:
                # Load shard from remote GS bucket.
                blob = self.bucket.blob(shard)
                content = blob.download_as_bytes()
                fp = io.BytesIO(content)

            reader = tfrecord_loader(fp, None, self.opts.features, None)

            for i, example in enumerate(reader):
                # Decode the example into the features format ...
                features = decode(example)

                # Broadcast class information.
                features[Schema.CLASS] = th.full(
                    (int(features[Schema.INSTANCE_NUM]), ),
                    self._index_from_class(features[Schema.CLASS]))

                # Transform the output and yield it.
                output = features
                if self.xfm:
                    output = self.xfm(output)
                yield output
        finally:
            # Ensure that the `fp` resource is properly released.
            if fp is not None:
                fp.close()
Example #7
def __iter__(self):
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        shard = worker_info.id, worker_info.num_workers
        np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
    else:
        shard = None
    it = reader.tfrecord_loader(
        data_path=self.data_path,
        index_path=self.index_path,
        description=self.description,
        shard=shard,
        sequence_description=self.sequence_description,
        compression_type=self.compression_type)
    if self.shuffle_queue_size:
        # Optional bounded-buffer shuffle and per-record transform.
        it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
    if self.transform:
        it = map(self.transform, it)
    return it
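All of these `__iter__`/`__next__` variants implement the pattern that `tfrecord.torch.dataset.TFRecordDataset` already packages up. A usage sketch with hypothetical paths and schema (note that the library requires an index file per shard when reading with multiple DataLoader workers):

import torch
from tfrecord.torch.dataset import TFRecordDataset

# Hypothetical file and schema; adjust to your data.
dataset = TFRecordDataset("train.tfrecord",
                          index_path=None,  # e.g. "train.index" for sharded reads
                          description={"image": "byte", "label": "int"},
                          shuffle_queue_size=256)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
batch = next(iter(loader))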