def load_file(self):
    # Read every record from the current file into memory.
    reader = tfrecord_loader(self.files[self.file_index], None, list(TFRECORD_KEYS))
    data = []
    for datum in reader:
        data.append(datum)
    return data

def samples_per_file(self, filename):
    # Count the records in a single tfrecord file by iterating it once.
    reader = tfrecord_loader(filename, None, list(TFRECORD_KEYS))
    count = 0
    for _ in reader:
        count += 1
    return count

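For context, a hedged usage sketch that totals the per-file counts over a list of shard files; the file list and the TFRECORD_KEYS feature keys are assumptions, not part of the snippets above.

# Sketch only: sums record counts over an assumed shard list.
from tfrecord.reader import tfrecord_loader

TFRECORD_KEYS = ("image", "label")  # assumed feature keys

def total_samples(files):
    total = 0
    for filename in files:
        reader = tfrecord_loader(filename, None, list(TFRECORD_KEYS))
        total += sum(1 for _ in reader)
    return total
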
def load_file(self):
    # Open the current file together with its .index companion so the
    # loader can shard reads across workers.
    reader = tfrecord_loader(self.files[self.file_index],
                             self.files[self.file_index].replace(".tfrecord", ".index"),
                             list(TFRECORD_KEYS),
                             self.shard)
    data = []
    for datum in reader:
        data.append([datum[key] for key in TFRECORD_KEYS])
    return data

def __iter__(self):
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        # Give each DataLoader worker its own shard and RNG seed.
        shard = worker_info.id, worker_info.num_workers
        np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
    else:
        shard = None
    it = reader.tfrecord_loader(self.data_path, self.index_path,
                                self.description, shard)
    if self.shuffle_queue_size:
        it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
    return it

def __next__(self):
    try:
        datum = next(self.reader)
    except StopIteration:
        # Current file is exhausted; advance to the next one or stop.
        if self.file_index >= len(self.files):
            raise StopIteration
        self.reader = tfrecord_loader(
            self.files[self.file_index],
            self.files[self.file_index].replace(".tfrecord", ".index"),
            list(TFRECORD_KEYS),
            self.shard)
        self.file_index += 1
        datum = next(self.reader)
    datum = [datum[key] for key in TFRECORD_KEYS]
    return datum

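This __next__ relies on state the excerpt does not show. A minimal sketch of the surrounding class follows; the class name, the eager opening of the first file, and TFRECORD_KEYS are assumptions chosen to match the method's use of self.reader, self.files, self.file_index, and self.shard.

# Hypothetical wrapper class; only illustrates the state __next__ expects.
from tfrecord.reader import tfrecord_loader

TFRECORD_KEYS = ("image", "label")  # assumed feature keys

class MultiFileTFRecordIterator:
    def __init__(self, files, shard=None):
        self.files = files
        self.shard = shard
        # Open the first file eagerly; file_index points at the next file.
        self.reader = tfrecord_loader(
            files[0], files[0].replace(".tfrecord", ".index"),
            list(TFRECORD_KEYS), shard)
        self.file_index = 1

    def __iter__(self):
        return self
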
def __iter__(self):
    train_type = 'train' if self.train else 'test'

    # Distribute shards among interleaved workers.
    worker_info = th.utils.data.get_worker_info()
    if worker_info is None:
        i0 = 0
        step = 1
    else:
        i0 = worker_info.id
        step = worker_info.num_workers

    for shard in self.shards[i0::step]:
        # Download blob and parse tfrecord.
        # TODO(ycho): Consider downloading in the background.
        fp = None
        try:
            if self.opts.local:
                # Load local shard.
                fp = open(str(shard), 'rb')
            else:
                # Load shard from remote GS bucket.
                blob = self.bucket.blob(shard)
                content = blob.download_as_bytes()
                fp = io.BytesIO(content)

            reader = tfrecord_loader(fp, None, self.opts.features, None)
            for i, example in enumerate(reader):
                # Decode example into features format ...
                features = decode(example)

                # Broadcast class information.
                features[Schema.CLASS] = th.full(
                    (int(features[Schema.INSTANCE_NUM]),),
                    self._index_from_class(features[Schema.CLASS]))

                # Transform output and return.
                output = features
                if self.xfm:
                    output = self.xfm(output)
                yield output
        finally:
            # Ensure that the `fp` resource is properly released.
            if fp is not None:
                fp.close()

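The remote branch above assumes a google-cloud-storage bucket handle on self.bucket. A minimal sketch of that setup, with the bucket and blob names being assumptions:

# Sketch only: how self.bucket could be constructed with google-cloud-storage.
import io
from google.cloud import storage

client = storage.Client()                          # uses ambient GCP credentials
bucket = client.bucket("my-tfrecord-bucket")       # assumed bucket name
blob = bucket.blob("shards/train-00000.tfrecord")  # assumed blob path
fp = io.BytesIO(blob.download_as_bytes())          # in-memory file, as in the loop above

Downloading the whole blob into a BytesIO keeps the tfrecord parser agnostic to the storage backend, at the cost of holding one full shard in memory at a time.
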
def __iter__(self):
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        # One (index, count) shard per DataLoader worker, plus a per-worker seed.
        shard = worker_info.id, worker_info.num_workers
        np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
    else:
        shard = None
    it = reader.tfrecord_loader(
        data_path=self.data_path,
        index_path=self.index_path,
        description=self.description,
        shard=shard,
        sequence_description=self.sequence_description,
        compression_type=self.compression_type)
    if self.shuffle_queue_size:
        it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
    if self.transform:
        it = map(self.transform, it)
    return it
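
An __iter__ like this is what the tfrecord package's TFRecordDataset exposes. A hedged end-to-end usage sketch follows; the file paths and the feature description are assumed.

# Sketch only: consuming such a dataset with a PyTorch DataLoader.
import torch
from tfrecord.torch.dataset import TFRecordDataset

dataset = TFRecordDataset(
    data_path="data.tfrecord",                      # assumed file
    index_path="data.index",                        # assumed index file
    description={"image": "byte", "label": "int"},  # assumed schema
    shuffle_queue_size=256)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=2)

for batch in loader:
    images, labels = batch["image"], batch["label"]
    break  # one batch is enough for the sketch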