import logging
from typing import Callable, List, Optional, Sequence

import pandas as pd
import ray

# MLDataset and MLDatasetType are expected to be provided by the surrounding
# module (they are part of Ray's MLDataset API).

logger = logging.getLogger(__name__)


def load_data(data: MLDatasetType,
              ignore: Optional[Sequence[str]] = None,
              indices: Optional[Sequence[int]] = None,
              **kwargs):
    # Default to loading every shard of the dataset
    indices = indices or list(range(0, data.num_shards()))
    # Each shard is an iterable of pandas DataFrames; concatenate it into one
    shards: List[pd.DataFrame] = [
        pd.concat(data.get_shard(i), copy=False) for i in indices
    ]
    # Concat all shards into a single local DataFrame
    local_df = pd.concat(shards, copy=False)
    if ignore:
        local_df = local_df[local_df.columns.difference(ignore)]
    return local_df

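# Illustrative usage sketch (not part of the original module): how a training
# worker might call load_data() to materialize a subset of shards as a single
# pandas DataFrame. The column names "row_id" and "label" are assumptions made
# for this example only.
def _example_load_worker_data(ds: MLDatasetType, worker_shards: Sequence[int]):
    # Pull only this worker's shards and drop a hypothetical bookkeeping column
    local_df = load_data(ds, ignore=["row_id"], indices=list(worker_shards))
    # Split the hypothetical label column from the features
    labels = local_df.pop("label")
    return local_df, labels
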
def get_n(data: MLDatasetType):
    return data.num_shards()

def to_torch(ds: MLDataset,
             world_size: int,
             world_rank: int,
             batch_size: int,
             collate_fn: Callable,
             shuffle: bool = False,
             shuffle_seed: int = None,
             local_rank: int = -1,
             prefer_node: str = None,
             prefetch: bool = False):
    """
    Create a DataLoader from an MLDataset.

    :param ds: the MLDataset
    :param world_size: the world size of distributed model training
    :param world_rank: create the DataLoader for the given world rank
    :param batch_size: the batch size of the DataLoader
    :param collate_fn: the collate_fn that creates tensors from a pandas DataFrame
    :param shuffle: whether to shuffle each batch of data
    :param shuffle_seed: the shuffle seed
    :param local_rank: the node-local rank. It must be provided if prefer_node
           is not None.
    :param prefer_node: the preferred node on which to create the MLDataset actor
    :param prefetch: prefetch the data of the DataLoader with one thread
    :return: a PyTorch DataLoader
    """
    # pylint: disable=C0415
    import torch
    from raydp.torch.torch_ml_dataset import PrefetchedDataLoader, TorchMLDataset

    num_shards = ds.num_shards()
    assert num_shards % world_size == 0, \
        (f"The number of shards of MLDataset({ds}) should be a multiple of "
         f"world_size({world_size})")
    multiplier = num_shards // world_size
    selected_ds = None
    if prefer_node is not None:
        assert 0 <= local_rank < world_size

        # get all actors
        # there should be only one actor_set because select_shards() is not
        # allowed after union()
        def location_check(actor):
            address = ray.actors(actor._actor_id.hex())["Address"]["IPAddress"]
            return address == prefer_node

        actors = ds.actor_sets[0].actors
        actor_indexes = [i for i, actor in enumerate(actors) if location_check(actor)]
        if len(actor_indexes) % multiplier != 0:
            selected_ds = None
            logger.warning(f"We could not find enough shard actors on the preferred "
                           f"node({prefer_node}), falling back to normal select_shards(). "
                           f"Found: ({actor_indexes}), whose length is not a multiple of "
                           f"num_shards({num_shards}) // world_size({world_size}).")
        else:
            shard_ids = actor_indexes[local_rank: local_rank + multiplier]
            selected_ds = ds.select_shards(shard_ids)
    if selected_ds is None:
        # fall back: pick every world_size-th shard starting at world_rank
        shard_ids = []
        i = world_rank
        step = world_size
        while i < num_shards:
            shard_ids.append(i)
            i += step
        selected_ds = ds.select_shards(shard_ids)

    selected_ds = selected_ds.batch(batch_size)
    torch_ds = TorchMLDataset(selected_ds, collate_fn, shuffle, shuffle_seed)
    data_loader = torch.utils.data.DataLoader(dataset=torch_ds,
                                              batch_size=None,
                                              batch_sampler=None,
                                              shuffle=False,
                                              num_workers=0,
                                              collate_fn=None,
                                              pin_memory=False,
                                              drop_last=False,
                                              sampler=None)
    if prefetch:
        data_loader = PrefetchedDataLoader(data_loader)
    return data_loader

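# Illustrative usage sketch (not part of the original module): a minimal
# collate_fn that turns a pandas DataFrame batch into feature/label tensors,
# and a helper showing how to_torch() could be called for one worker. The
# "label" column name, batch size, and seed are assumptions for this example.
def _example_collate_fn(df: pd.DataFrame):
    import torch  # local import, mirroring the style used in to_torch()
    feature_cols = [c for c in df.columns if c != "label"]
    features = torch.as_tensor(df[feature_cols].values, dtype=torch.float32)
    labels = torch.as_tensor(df["label"].values, dtype=torch.float32)
    return features, labels


def _example_build_loader(ds: MLDataset, world_size: int, world_rank: int):
    # Each rank builds its own DataLoader over its slice of the shards
    return to_torch(ds,
                    world_size=world_size,
                    world_rank=world_rank,
                    batch_size=256,
                    collate_fn=_example_collate_fn,
                    shuffle=True,
                    shuffle_seed=42,
                    prefetch=False)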