def read_as_tfdataset(path, output_types, output_shapes=None, *args, **kwargs):
    """Load an Orca parquet dataset rooted at *path* as a tf.data.Dataset.

    Records are streamed one at a time through a Python generator that
    decodes each "chunk=" partition directory on demand.

    :param path: root directory of the Orca parquet dataset.
    :param output_types: tf output types for Dataset.from_generator.
    :param output_shapes: optional tf output shapes for Dataset.from_generator.
    :return: a tf.data.Dataset yielding one record dict per element.
    """
    path, _ = pa_fs(path)
    import tensorflow as tf

    # Schema is stored next to the data in a JSON sidecar file.
    meta_file = os.path.join(path, "_orca_metadata")
    schema = decode_schema(open_text(meta_file)[0])

    def record_stream():
        # Walk the dataset root and decode every "chunk=" partition.
        for _, sub_dirs, _ in os.walk(path):
            for sub_dir in sub_dirs:
                if not sub_dir.startswith("chunk="):
                    continue
                table = pq.read_table(os.path.join(path, sub_dir))
                frame = decode_feature_type_ndarray(table.to_pandas(), schema)
                yield from frame.to_dict("records")

    return tf.data.Dataset.from_generator(record_stream,
                                          output_types=output_types,
                                          output_shapes=output_shapes)
def read_as_tfdataset(path, output_types, config=None, output_shapes=None,
                      *args, **kwargs):
    """Load an Orca parquet dataset rooted at *path* as a tf.data.Dataset.

    Chunk partitions are enumerated up front and handed to a
    ParquetIterable, which performs optional sharded reading.

    :param path: root directory of the Orca parquet dataset.
    :param output_types: tf output types for Dataset.from_generator.
    :param config: optional dict that may carry "num_shards" and "rank"
        for sharded reading; when omitted, all chunks are read.
    :param output_shapes: optional tf output shapes for Dataset.from_generator.
    :return: a tf.data.Dataset yielding one record dict per element.
    """
    path, _ = pa_fs(path)
    import tensorflow as tf

    # BUG FIX: config defaults to None but was dereferenced with
    # config.get(...) unconditionally, raising AttributeError whenever
    # the caller omitted it. Treat a missing config as "no sharding".
    config = config or {}

    schema_path = os.path.join(path, "_orca_metadata")
    j_str = open_text(schema_path)[0]
    schema = decode_schema(j_str)

    # Collect every "chunk=" partition directory under the dataset root.
    row_group = []
    for root, dirs, files in os.walk(path):
        for name in dirs:
            if name.startswith("chunk="):
                chunk_path = os.path.join(path, name)
                row_group.append(chunk_path)

    dataset = ParquetIterable(row_group=row_group,
                              schema=schema,
                              num_shards=config.get("num_shards"),
                              rank=config.get("rank"))
    return tf.data.Dataset.from_generator(dataset,
                                          output_types=output_types,
                                          output_shapes=output_shapes)
def read_as_dataloader(path, config=None, transforms=None, batch_size=1,
                       *args, **kwargs):
    """Load an Orca parquet dataset rooted at *path* as a torch DataLoader.

    All chunks selected for this shard are decoded eagerly into memory,
    then served through an IterableDataset whose record range is split
    across DataLoader workers by ``worker_init_fn``.

    :param path: root directory of the Orca parquet dataset.
    :param config: optional dict; may carry "num_shards" and "rank" for
        sharded reading, and "num_workers" for the DataLoader (default 0).
    :param transforms: optional callable applied to each record.
    :param batch_size: DataLoader batch size.
    :return: a torch.utils.data.DataLoader over the decoded records.
    """
    path, _ = pa_fs(path)

    # BUG FIX: config defaults to None but was dereferenced with
    # config.get(...); a missing config now means "no sharding, 0 workers".
    # (Also dropped an unused "import tensorflow" from the original.)
    config = config or {}

    schema_path = os.path.join(path, "_orca_metadata")
    j_str = open_text(schema_path)[0]
    schema = decode_schema(j_str)

    # Collect every "chunk=" partition directory under the dataset root.
    row_group = []
    for root, dirs, files in os.walk(path):
        for name in dirs:
            if name.startswith("chunk="):
                chunk_path = os.path.join(path, name)
                row_group.append(chunk_path)

    class ParquetIterableDataset(torch.utils.data.IterableDataset):
        def __init__(self, row_group, num_shards=None, rank=None,
                     transforms=None):
            # BUG FIX: the original called super(ParquetDataset).__init__(),
            # which names a different class and never actually invokes the
            # parent initializer on this instance.
            super().__init__()
            self.row_group = row_group
            # To get the indices we expect
            self.row_group.sort()
            self.num_shards = num_shards
            self.rank = rank
            self.datapiece = None
            self.transforms = transforms

            if self.num_shards is None or self.rank is None:
                # No sharding requested: read every chunk.
                filter_row_group_indexed = list(range(len(self.row_group)))
            else:
                assert self.num_shards <= len(self.row_group), \
                    "num_shards should be not larger than partitions." \
                    "but got num_shards {} with partitions {}." \
                    .format(self.num_shards, len(self.row_group))
                assert self.rank < self.num_shards, \
                    "shard index should be included in [0,num_shard)," \
                    "but got rank {} with num_shard {}.".format(
                        self.rank, self.num_shards)
                # Round-robin assignment of chunks to this shard.
                filter_row_group_indexed = [
                    index for index in range(len(self.row_group))
                    if index % self.num_shards == self.rank]

            # Eagerly decode the selected chunks into a flat record list.
            data_record = []
            for select_chunk_path in [self.row_group[i]
                                      for i in filter_row_group_indexed]:
                pq_table = pq.read_table(select_chunk_path)
                df = decode_feature_type_ndarray(pq_table.to_pandas(), schema)
                data_record.extend(df.to_dict("records"))
            self.datapiece = data_record
            self.cur = 0
            self.cur_tail = len(self.datapiece)

        def __iter__(self):
            return self

        def __next__(self):
            # move iter here so we can do transforms
            if self.cur < self.cur_tail:
                elem = self.datapiece[self.cur]
                self.cur += 1
                return self.transforms(elem) if self.transforms else elem
            raise StopIteration

    def worker_init_fn(w_id):
        # Split the [cur, cur_tail) record range evenly across workers.
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset
        iter_start = dataset.cur
        iter_end = dataset.cur_tail
        # BUG FIX: the original computed
        #   iter_end - iter_start / float(num_workers)
        # (precedence bug — only iter_start was divided), giving every
        # worker nearly the full range. Divide the whole span instead.
        per_worker = int(
            math.ceil((iter_end - iter_start) /
                      float(worker_info.num_workers)))
        w_id = worker_info.id
        dataset.cur = iter_start + w_id * per_worker
        dataset.cur_tail = min(dataset.cur + per_worker, iter_end)

    dataset = ParquetIterableDataset(
        row_group=row_group,
        num_shards=config.get("num_shards"),
        rank=config.get("rank"),
        transforms=transforms)

    return torch.utils.data.DataLoader(dataset,
                                       num_workers=config.get("num_workers", 0),
                                       batch_size=batch_size,
                                       worker_init_fn=worker_init_fn)
def read_as_dataloader(path, config=None, transforms=None, batch_size=1,
                       *args, **kwargs):
    """Load an Orca parquet dataset rooted at *path* as a torch DataLoader.

    Chunk partitions are enumerated up front and delegated to a
    ParquetIterable; a thin IterableDataset adapter exposes it to torch,
    and ``worker_init_fn`` splits the record range across workers.

    :param path: root directory of the Orca parquet dataset.
    :param config: optional dict; may carry "num_shards" and "rank" for
        sharded reading, and "num_workers" for the DataLoader (default 0).
    :param transforms: optional callable applied to each record.
    :param batch_size: DataLoader batch size.
    :return: a torch.utils.data.DataLoader over the decoded records.
    """
    path, _ = pa_fs(path)
    import torch

    # BUG FIX: config defaults to None but was dereferenced with
    # config.get(...); a missing config now means "no sharding, 0 workers".
    # (Also dropped an unused "import tensorflow" from the original.)
    config = config or {}

    schema_path = os.path.join(path, "_orca_metadata")
    j_str = open_text(schema_path)[0]
    schema = decode_schema(j_str)

    # Collect every "chunk=" partition directory under the dataset root.
    row_group = []
    for root, dirs, files in os.walk(path):
        for name in dirs:
            if name.startswith("chunk="):
                chunk_path = os.path.join(path, name)
                row_group.append(chunk_path)

    class ParquetIterableDataset(torch.utils.data.IterableDataset):
        def __init__(self, row_group, schema, num_shards=None, rank=None,
                     transforms=None):
            super().__init__()
            self.iterator = ParquetIterable(row_group, schema, num_shards,
                                            rank, transforms)
            # Mirror the iterable's cursor bounds so worker_init_fn can
            # partition the record range per worker.
            self.cur = self.iterator.cur
            self.cur_tail = self.iterator.cur_tail

        def __iter__(self):
            return self.iterator.__iter__()

        def __next__(self):
            # BUG FIX: the original called self.iterator.__next__() but
            # dropped the value (no return), so every item came back as
            # None from this dataset.
            return self.iterator.__next__()

    def worker_init_fn(w_id):
        # Split the [cur, cur_tail) record range evenly across workers.
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset
        iter_start = dataset.cur
        iter_end = dataset.cur_tail
        # BUG FIX: the original computed
        #   iter_end - iter_start / float(num_workers)
        # (precedence bug — only iter_start was divided), giving every
        # worker nearly the full range. Divide the whole span instead.
        per_worker = int(
            math.ceil((iter_end - iter_start) /
                      float(worker_info.num_workers)))
        w_id = worker_info.id
        dataset.cur = iter_start + w_id * per_worker
        dataset.cur_tail = min(dataset.cur + per_worker, iter_end)

    dataset = ParquetIterableDataset(row_group=row_group,
                                     schema=schema,
                                     num_shards=config.get("num_shards"),
                                     rank=config.get("rank"),
                                     transforms=transforms)
    return torch.utils.data.DataLoader(dataset,
                                       num_workers=config.get("num_workers", 0),
                                       batch_size=batch_size,
                                       worker_init_fn=worker_init_fn)