Example #1
    def fit(self,
            train_ds: MLDataset,
            evaluate_ds: Optional[MLDataset] = None,
            num_steps=None,
            profile=False,
            reduce_results=True,
            max_retries=3,
            info=None) -> None:
        super().fit(train_ds, evaluate_ds)
        # Batch the training MLDataset and convert it into the dataset type
        # consumed by the underlying trainer.
        train_ds = train_ds.batch(self._batch_size)
        train_tf_ds = self._create_tf_ds(train_ds)

        if evaluate_ds is not None:
            evaluate_ds = evaluate_ds.batch(self._batch_size)
            evaluate_tf_ds = self._create_tf_ds(evaluate_ds)
        else:
            evaluate_tf_ds = None

        self._create_trainer(train_tf_ds, evaluate_tf_ds)
        assert self._trainer is not None
        for i in range(self._num_epochs):
            stats = self._trainer.train(
                num_steps=num_steps,
                profile=profile,
                reduce_results=reduce_results,
                max_retries=max_retries,
                info=info)
            print(f"Epoch-{i}: {stats}")

        if evaluate_tf_ds is not None:
            print(self._trainer.validate(num_steps, profile, reduce_results, info))
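
A minimal usage sketch for this fit() entry point follows; only the fit() keywords come from the signature above, while the estimator instance and the MLDatasets are assumed to have been built upstream, and the comments on the trainer flags reflect the usual Ray SGD train() semantics.

# Hedged usage sketch: `estimator` is an instance of the class this method
# belongs to (its construction is not shown in the snippet), and the two
# MLDatasets are assumed to exist already.
estimator.fit(train_ds=train_ml_dataset,
              evaluate_ds=eval_ml_dataset,   # optional evaluation MLDataset
              num_steps=None,                # None -> run full epochs
              profile=False,                 # forwarded to the trainer's train()
              reduce_results=True,           # aggregate per-worker results
              max_retries=3)                 # retries on worker failure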
Example #2
def _can_load_distributed(source: Data) -> bool:
    """Returns True if it might be possible to use distributed data loading"""
    from xgboost_ray.data_sources.ml_dataset import MLDataset
    from xgboost_ray.data_sources.modin import Modin

    if isinstance(source, (int, float, bool)):
        return False
    elif MLDataset.is_data_type(source):
        return True
    elif Modin.is_data_type(source):
        return True
    elif isinstance(source, str):
        # Strings should point to files or URLs
        # Usually parquet files point to directories
        return source.endswith(".parquet")
    elif isinstance(source, Sequence):
        # Sequence of strings should point to files or URLs
        return isinstance(source[0], str)
    elif isinstance(source, Iterable):
        # If we get an iterable but not a sequence, the best we can do
        # is check if we have a known non-distributed object
        if isinstance(source, (pd.DataFrame, pd.Series, np.ndarray)):
            return False

    # By default, allow distributed loading.
    return True
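
To make the decision order concrete, here is a small self-contained stand-in (not the xgboost_ray function itself) that mirrors the scalar, string, and sequence branches and the results they imply:

from collections.abc import Sequence

def can_load_distributed_simplified(source) -> bool:
    # Simplified stand-in for _can_load_distributed(): scalars never qualify,
    # parquet paths do, sequences qualify when they hold path strings, and
    # everything else defaults to True.
    if isinstance(source, (int, float, bool)):
        return False
    if isinstance(source, str):
        return source.endswith(".parquet")
    if isinstance(source, Sequence):
        return isinstance(source[0], str)
    return True

assert can_load_distributed_simplified(1.0) is False
assert can_load_distributed_simplified("data.parquet") is True      # parquet path
assert can_load_distributed_simplified("data.csv") is False         # single CSV file
assert can_load_distributed_simplified(["a.csv", "b.csv"]) is True  # list of paths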
Example #3
    def _create_tf_ds(self, ds: MLDataset) -> TorchMLDataset:
        # Despite its name, this helper wraps the (already batched) MLDataset
        # as a TorchMLDataset using the estimator's feature/label spec.
        return ds.to_torch(self._feature_columns,
                           self._feature_shapes,
                           self._feature_types,
                           self._label_column,
                           self._label_shape,
                           self._label_type)
Example #4
    def fit(self,
            train_ds: MLDataset,
            evaluate_ds: Optional[MLDataset] = None) -> None:
        super().fit(train_ds, evaluate_ds)

        def model_creator(config):
            # https://github.com/ray-project/ray/issues/5914
            import tensorflow.keras as keras  # pylint: disable=C0415, W0404

            model: keras.Model = keras.models.model_from_json(
                self._serialized_model)
            optimizer = keras.optimizers.get(self._serialized_optimizer)
            loss = keras.losses.get(self._serialized_loss)
            metrics = [keras.metrics.get(m) for m in self._serialized_metrics]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            return model

        train_ds = train_ds.batch(self._batch_size)
        train_tf_ds = self._create_tf_ds(train_ds)

        if evaluate_ds is not None:
            evaluate_ds = evaluate_ds.batch(self._batch_size)
            evaluate_tf_ds = self._create_tf_ds(evaluate_ds)
        else:
            evaluate_tf_ds = None

        def data_creator(config):
            # Derive this worker's shard index from TF_CONFIG (set by the
            # distributed runtime); fall back to -1 when it is not set.
            if "TF_CONFIG" in os.environ:
                tf_config = json.loads(os.environ["TF_CONFIG"])
                world_rank = tf_config["task"]["index"]
            else:
                world_rank = -1
            batch_size = config["batch_size"]
            get_shard_config = config.get("get_shard", {})
            if "shuffle" in config:
                get_shard_config["shuffle"] = config["shuffle"]
            train_data = train_tf_ds.get_shard(
                world_rank, **get_shard_config).repeat().batch(batch_size)
            options = tf.data.Options()
            # Shards are already assigned via get_shard(), so turn off
            # tf.data auto-sharding to avoid re-sharding by the strategy.
            options.experimental_distribute.auto_shard_policy = \
                tf.data.experimental.AutoShardPolicy.OFF
            train_data = train_data.with_options(options)
            evaluate_data = None
            if evaluate_tf_ds is not None:
                evaluate_data = evaluate_tf_ds.get_shard(
                    world_rank, **get_shard_config).batch(batch_size)
                evaluate_data = evaluate_data.with_options(options)
            return train_data, evaluate_data

        self._trainer = TFTrainer(model_creator=model_creator,
                                  data_creator=data_creator,
                                  num_replicas=self._num_workers,
                                  **self._extra_config)
        for i in range(self._num_epochs):
            stats = self._trainer.train()
            print(f"Epoch-{i}: {stats}")

        if evaluate_tf_ds is not None:
            print(self._trainer.validate())
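
The data_creator above reads each worker's shard index from the TF_CONFIG environment variable set by the distributed runtime. A self-contained illustration of that parsing, using a made-up two-worker TF_CONFIG value:

import json
import os

# Made-up TF_CONFIG for a two-worker cluster; in real runs the runtime sets it.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["host0:2222", "host1:2222"]},
    "task": {"type": "worker", "index": 1},
})

tf_config = json.loads(os.environ["TF_CONFIG"])
world_rank = tf_config["task"]["index"]
print(world_rank)  # 1 -> this worker reads its data via get_shard(1, ...)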
Example #5
    def load_data(data: MLDatasetType,
                  ignore: Optional[Sequence[str]] = None,
                  indices: Optional[Sequence[int]] = None,
                  **kwargs):
        # Load the requested shards (all shards by default) into pandas DataFrames.
        indices = indices or list(range(0, data.num_shards()))

        shards: List[pd.DataFrame] = [
            pd.concat(data.get_shard(i), copy=False) for i in indices
        ]

        # Concat all shards
        local_df = pd.concat(shards, copy=False)

        if ignore:
            local_df = local_df[local_df.columns.difference(ignore)]

        return local_df
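
The core pattern in load_data() is a two-level concatenation: each shard yields an iterable of pandas DataFrame batches, collapsed once per shard and then once across shards. A self-contained illustration with plain pandas (the nested lists stand in for data.get_shard(i)):

import pandas as pd

# Stand-in for MLDataset shards: each "shard" is an iterable of DataFrame batches.
fake_shards = [
    [pd.DataFrame({"a": [1, 2]}), pd.DataFrame({"a": [3]})],  # shard 0
    [pd.DataFrame({"a": [4, 5]})],                            # shard 1
]

# Same two-level pd.concat as load_data(): first within each shard, then across shards.
local_df = pd.concat([pd.concat(shard, copy=False) for shard in fake_shards],
                     copy=False)
print(local_df["a"].tolist())  # [1, 2, 3, 4, 5]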
Example #6
def _detect_distributed(source: Data) -> bool:
    """Returns True if we should try to use distributed data loading"""
    from xgboost_ray.data_sources.ml_dataset import MLDataset
    from xgboost_ray.data_sources.modin import Modin
    if not _can_load_distributed(source):
        return False
    if MLDataset.is_data_type(source):
        return True
    if Modin.is_data_type(source):
        return True
    if isinstance(source, Iterable) and not isinstance(source, str) and \
       not (isinstance(source, Sequence) and isinstance(source[0], str)):
        # This is an iterable but not a Sequence of strings, and not a
        # pandas DataFrame, Series, or numpy array.
        # Default to non-distributed loading here; this can be overridden
        # by passing `distributed=True` to the RayDMatrix object.
        return False

    # Otherwise, assume distributed loading is possible
    return True
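
The last branch is the interesting one: it filters out generic iterables that are neither strings nor sequences of path strings, leaving those sources to an explicit `distributed=True` on the RayDMatrix. A small self-contained check of the predicates that branch relies on:

from collections.abc import Iterable, Sequence

import pandas as pd

paths = ["part-0.parquet", "part-1.parquet"]
frames = (pd.DataFrame({"x": [i]}) for i in range(2))  # a generator of DataFrames

# A list of path strings is a Sequence of str, so it falls through to the
# final `return True` and is detected as distributed.
print(isinstance(paths, Sequence) and isinstance(paths[0], str))   # True

# A generator is Iterable but neither a str nor a Sequence of str, so this
# branch returns False and distributed loading has to be requested explicitly.
print(isinstance(frames, Iterable), isinstance(frames, Sequence))  # True False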
Example #7
    def get_n(data: MLDatasetType):
        # The size of an MLDataset is its number of shards.
        return data.num_shards()
Example #8
    def to_torch(
            ds: MLDataset,
            world_size: int,
            world_rank: int,
            batch_size: int,
            collate_fn: Callable,
            shuffle: bool = False,
            shuffle_seed: int = None,
            local_rank: int = -1,
            prefer_node: str = None,
            prefetch: bool = False):
        """
        Create DataLoader from a MLDataset
        :param ds: the MLDataset
        :param world_size: the world_size of distributed model training
        :param world_rank: create the DataLoader for the given world_rank
        :param batch_size: the batch_size of the DtaLoader
        :param collate_fn: the collate_fn that create tensors from a pandas DataFrame
        :param shuffle: whether shuffle each batch of data
        :param shuffle_seed: the shuffle seed
        :param local_rank: the node local rank. It must be provided if prefer_node is
               not None.
        :param prefer_node: the prefer node for create the MLDataset actor
        :param prefetch: prefetch the data of DataLoader with one thread
        :return: a pytorch DataLoader
        """
        # pylint: disable=C0415
        import torch
        from raydp.torch.torch_ml_dataset import PrefetchedDataLoader, TorchMLDataset

        num_shards = ds.num_shards()
        assert num_shards % world_size == 0, \
            (f"The number shards of MLDataset({ds}) should be a multiple of "
             f"world_size({world_size})")
        multiplier = num_shards // world_size

        selected_ds = None
        if prefer_node is not None:
            assert 0 <= local_rank < world_size

            # get all actors
            # there should be only one actor_set because select_shards() is not allowed
            # after union()

            def location_check(actor):
                address = ray.actors(actor._actor_id.hex())["Address"]["IPAddress"]
                return address == prefer_node

            actors = ds.actor_sets[0].actors
            actor_indexes = [i for i, actor in enumerate(actors) if location_check(actor)]
            if len(actor_indexes) % multiplier != 0:
                selected_ds = None
                logger.warning(f"Could not find enough shard actors on the preferred "
                               f"node({prefer_node}); falling back to normal select_shards(). "
                               f"Found ({actor_indexes}), whose length is not a multiple of "
                               f"num_shards({num_shards}) // world_size({world_size}).")
            else:
                # Give each local rank its own contiguous block of `multiplier` shards.
                shard_ids = actor_indexes[local_rank * multiplier:(local_rank + 1) * multiplier]
                selected_ds = ds.select_shards(shard_ids)

        if selected_ds is None:
            # Fall back to round-robin assignment: worker `world_rank` takes shards
            # world_rank, world_rank + world_size, world_rank + 2 * world_size, ...
            shard_ids = list(range(world_rank, num_shards, world_size))
            selected_ds = ds.select_shards(shard_ids)

        selected_ds = selected_ds.batch(batch_size)
        torch_ds = TorchMLDataset(selected_ds, collate_fn, shuffle, shuffle_seed)
        # Batching and shuffling are already handled by the (Torch)MLDataset above,
        # so the DataLoader is created with batching, sampling and workers disabled.
        data_loader = torch.utils.data.DataLoader(dataset=torch_ds,
                                                  batch_size=None,
                                                  batch_sampler=None,
                                                  shuffle=False,
                                                  num_workers=0,
                                                  collate_fn=None,
                                                  pin_memory=False,
                                                  drop_last=False,
                                                  sampler=None)
        if prefetch:
            data_loader = PrefetchedDataLoader(data_loader)
        return data_loader
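
The fallback path assigns shards round-robin: worker world_rank takes shards world_rank, world_rank + world_size, world_rank + 2 * world_size, and so on. A tiny self-contained check of that assignment for 8 shards across 4 workers:

num_shards, world_size = 8, 4

# Round-robin shard assignment used by the fallback branch above.
assignment = {rank: list(range(rank, num_shards, world_size))
              for rank in range(world_size)}
print(assignment)  # {0: [0, 4], 1: [1, 5], 2: [2, 6], 3: [3, 7]}

# Every shard is used exactly once and each worker gets num_shards // world_size shards.
flat = sorted(s for shards in assignment.values() for s in shards)
assert flat == list(range(num_shards))
assert all(len(shards) == num_shards // world_size for shards in assignment.values())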