コード例 #1
0
    def load_data(data: MLDatasetType,
                  ignore: Optional[Sequence[str]] = None,
                  indices: Optional[Sequence[int]] = None,
                  **kwargs):
        """Load selected shards of ``data`` into one local pandas DataFrame.

        :param data: the dataset to read shards from (via ``num_shards()`` and
            ``get_shard(i)``).
        :param ignore: column names to drop from the result. Note that
            ``Index.difference`` (pandas default ``sort=None``) sorts the
            remaining column labels, so when ``ignore`` is non-empty the
            returned columns are in sorted order rather than original order.
        :param indices: shard indices to load. ``None`` loads every shard;
            an explicitly empty sequence selects no shards.
        :param kwargs: accepted for interface compatibility; unused here.
        :return: the concatenated DataFrame (an empty DataFrame when no shards
            are selected).
        """
        # Only ``None`` means "all shards"; ``[]`` must not silently expand to
        # every shard (that would duplicate data across callers).
        if indices is None:
            indices = range(data.num_shards())

        # ``get_shard(i)`` is passed straight to ``pd.concat``, so it presumably
        # yields the shard's DataFrame chunks; flatten each shard to one frame.
        shards: List[pd.DataFrame] = [
            pd.concat(data.get_shard(i), copy=False) for i in indices
        ]

        if not shards:
            # pd.concat([]) raises "No objects to concatenate".
            return pd.DataFrame()

        # Concat all shards
        local_df = pd.concat(shards, copy=False)

        if ignore:
            # Kept as-is for consistency with existing behavior: this reorders
            # the surviving columns (Index.difference sorts by default).
            local_df = local_df[local_df.columns.difference(ignore)]

        return local_df
コード例 #2
0
 def get_n(data: MLDatasetType):
     """Return the shard count reported by ``data.num_shards()``."""
     shard_count = data.num_shards()
     return shard_count
コード例 #3
0
ファイル: dataset.py プロジェクト: oap-project/raydp
    def to_torch(
            ds: MLDataset,
            world_size: int,
            world_rank: int,
            batch_size: int,
            collate_fn: Callable,
            shuffle: bool = False,
            shuffle_seed: int = None,
            local_rank: int = -1,
            prefer_node: str = None,
            prefetch: bool = False):
        """
        Create a PyTorch DataLoader over this rank's share of a MLDataset.

        Each rank receives ``num_shards // world_size`` shards. By default rank
        ``r`` takes the strided shards ``r, r + world_size, r + 2 * world_size, ...``.
        When ``prefer_node`` is given, the shards whose actors report that IP
        address are tried first: each ``local_rank`` gets its own contiguous,
        non-overlapping group of them. If the node does not host enough such
        shards for this ``local_rank``, a warning is logged and the default
        strided selection is used instead.

        :param ds: the MLDataset
        :param world_size: the world_size of distributed model training
        :param world_rank: create the DataLoader for the given world_rank
        :param batch_size: the batch_size of the DataLoader
        :param collate_fn: the collate_fn that creates tensors from a pandas DataFrame
        :param shuffle: whether to shuffle each batch of data
        :param shuffle_seed: the shuffle seed
        :param local_rank: the node local rank. It must be provided if prefer_node is
               not None.
        :param prefer_node: the preferred node (IP address) whose MLDataset actors
               should serve this rank
        :param prefetch: prefetch the data of DataLoader with one thread
        :return: a pytorch DataLoader
        :raises AssertionError: if the shard count is not a multiple of
               world_size, or prefer_node is set and local_rank is not in
               [0, world_size).
        """
        # pylint: disable=C0415
        import torch
        from raydp.torch.torch_ml_dataset import PrefetchedDataLoader, TorchMLDataset

        num_shards = ds.num_shards()
        assert num_shards % world_size == 0, \
            (f"The number shards of MLDataset({ds}) should be a multiple of "
             f"world_size({world_size})")
        # Number of shards each rank consumes.
        multiplier = num_shards // world_size

        selected_ds = None
        if prefer_node is not None:
            assert 0 <= local_rank < world_size

            # There should be only one actor_set, because select_shards() is
            # not allowed after union().
            def location_check(actor):
                address = ray.actors(actor._actor_id.hex())["Address"]["IPAddress"]
                return address == prefer_node

            actors = ds.actor_sets[0].actors
            actor_indexes = [i for i, actor in enumerate(actors) if location_check(actor)]

            # Each local rank takes its own contiguous block of `multiplier`
            # shards; the previous `[local_rank:local_rank + multiplier]` slice
            # overlapped between neighbouring local ranks.
            start = local_rank * multiplier
            shard_ids = actor_indexes[start:start + multiplier]
            if len(actor_indexes) % multiplier != 0 or len(shard_ids) != multiplier:
                logger.warning(
                    "We could not find enough shard actors in prefer node(%s) for "
                    "local_rank(%d), fall back to normal select_shards(). "
                    "Found: (%s) which length is not a multiple of "
                    "num_shards(%d) // world_size(%d), or is too short for "
                    "this local_rank.",
                    prefer_node, local_rank, actor_indexes, num_shards, world_size)
            else:
                selected_ds = ds.select_shards(shard_ids)

        if selected_ds is None:
            # Strided assignment: rank r takes shards r, r + world_size, ...
            shard_ids = list(range(world_rank, num_shards, world_size))
            selected_ds = ds.select_shards(shard_ids)

        selected_ds = selected_ds.batch(batch_size)
        torch_ds = TorchMLDataset(selected_ds, collate_fn, shuffle, shuffle_seed)
        # batch_size=None disables DataLoader's automatic batching: batching is
        # already done upstream by selected_ds.batch(batch_size).
        data_loader = torch.utils.data.DataLoader(dataset=torch_ds,
                                                  batch_size=None,
                                                  batch_sampler=None,
                                                  shuffle=False,
                                                  num_workers=0,
                                                  collate_fn=None,
                                                  pin_memory=False,
                                                  drop_last=False,
                                                  sampler=None)
        if prefetch:
            data_loader = PrefetchedDataLoader(data_loader)
        return data_loader