Example 1
    def evaluate(self,
                 data,
                 batch_size=32,
                 num_steps=None,
                 profile=False,
                 info=None,
                 feature_cols=None,
                 label_cols=None):
        """
        Evaluates a PyTorch model given validation data.
        Note that only accuracy for classification with zero-based labels is supported by
        default. You can override validate_batch in TrainingOperator for other metrics.
        Under the hood, this calls `TrainingOperator.validate()` on N parallel workers
        simultaneously.

        :param data: An instance of SparkXShards, a Spark DataFrame or a function that
               takes config and batch_size as argument and returns a PyTorch DataLoader for
               validation.
        :param batch_size: The number of samples per batch for each worker. Default is 32.
               The total batch size would be batch_size * workers_per_node * num_nodes.
               If your validation data is a function, you can set batch_size to be the input
               batch_size of the function for the PyTorch DataLoader.
        :param num_steps: The number of batches to compute the validation results on. This
               corresponds to the number of times `TrainingOperator.validate_batch` is called.
        :param profile: Boolean. Whether to return time stats for the evaluation procedure.
               Default is False.
        :param info: An optional dictionary that can be passed to the TrainingOperator
               for validate.
        :param feature_cols: feature column names if data is a Spark DataFrame.
        :param label_cols: label column names if data is a Spark DataFrame.

        :return: A dictionary of metrics for the given data, including validation accuracy and loss.
                You can also provide custom metrics by passing in a custom training_operator_cls
                when creating the Estimator.
        """
        from bigdl.orca.data import SparkXShards
        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             label_cols=label_cols,
                                             mode="evaluate",
                                             num_workers=self.num_workers)
        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols)
            from bigdl.orca.data.utils import process_spark_xshards
            ray_xshards = process_spark_xshards(data, self.num_workers)

            def transform_func(worker, partition_refs):
                data_creator = partition_refs_to_creator(partition_refs)
                # Should not wrap DistributedSampler on DataLoader for SparkXShards input.
                return worker.validate.remote(data_creator, batch_size,
                                              num_steps, profile, info, False)

            worker_stats = ray_xshards.reduce_partitions_for_actors(
                self.remote_workers, transform_func)
        else:
            assert isinstance(data, types.FunctionType), \
                "data should be either an instance of SparkXShards or a callable function, but " \
                "got type: {}".format(type(data))

            params = dict(data_creator=data,
                          batch_size=batch_size,
                          num_steps=num_steps,
                          profile=profile,
                          info=info)

            worker_stats = ray.get(
                [w.validate.remote(**params) for w in self.remote_workers])
        return self._process_stats(worker_stats)
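For the callable-`data` path above, a minimal sketch of a validation data creator follows. It only illustrates the documented contract (a function taking config and batch_size and returning a PyTorch DataLoader); the estimator construction is omitted and `est` is assumed to be an existing Orca PyTorch Estimator.

import torch
from torch.utils.data import DataLoader, TensorDataset

def val_loader_creator(config, batch_size):
    # Illustrative validation set: 256 random samples with zero-based class labels,
    # matching the "config and batch_size in, DataLoader out" contract documented above.
    x = torch.randn(256, 10)
    y = torch.randint(0, 2, (256,))
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=False)

# metrics = est.evaluate(data=val_loader_creator, batch_size=32)  # est: an existing Estimator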
Example 2
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            verbose=1,
            callbacks=None,
            validation_data=None,
            class_weight=None,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            data_config=None,
            feature_cols=None,
            label_cols=None,
            model_dir=None):
        """
        Train this TensorFlow model with train data.

        :param data: train data. It can be XShards, Spark DataFrame or creator function which
               returns Iter or DataLoader.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
               numpy arrays.
        :param epochs: Number of epochs to train the model. Default: 1.
        :param batch_size: Batch size used for training. Default: 32.
        :param verbose: Prints output of one model if true.
        :param callbacks: List of Keras compatible callbacks to apply during training.
        :param validation_data: validation data. Validation data type should be the same
               as train data.
        :param class_weight: Optional dictionary mapping class indices (integers) to a weight
               (float) value, used for weighting the loss function. This can be useful to tell
               the model to "pay more attention" to samples from an under-represented class.
        :param steps_per_epoch: Total number of steps (batches of samples) before declaring one
               epoch finished and starting the next epoch. If `steps_per_epoch` is `None`, the
               epoch will run until the input dataset is exhausted.
        :param validation_steps: Total number of steps (batches of samples) to draw before stopping
               when performing validation at the end of every epoch. Default: None.
        :param validation_freq: Only relevant if validation data is provided. Integer or
               `collections_abc.Container` instance (e.g. list, tuple, etc.). If an integer,
               specifies how many training epochs to run before a new validation run is performed.
        :param data_config: An optional dictionary that can be passed to data creator function.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data. Only used when data is a Spark DataFrame
               or an XShards of Pandas DataFrame. Default: None.
        :param model_dir: Directory where the workers persist training states (model weights and
               the current epoch) so that the driver can load them back after training.
        :return:
        """
        import numpy as np
        sc = OrcaContext.get_spark_context()

        init_params = dict(model_creator=self.model_creator,
                           compile_args_creator=self.compile_args_creator,
                           config=self.config,
                           verbose=self.verbose,
                           size=self.num_workers,
                           mode="fit",
                           cluster_info=self._get_cluster_info(sc),
                           model_dir=self.model_dir,
                           epoch=self.epoch)

        params = dict(epochs=epochs,
                      batch_size=batch_size,
                      verbose=verbose,
                      callbacks=callbacks,
                      class_weight=class_weight,
                      steps_per_epoch=steps_per_epoch,
                      validation_steps=validation_steps,
                      validation_freq=validation_freq,
                      data_config=data_config)

        # Convert the DataFrame to XShards; ensure num_partitions >= num_workers.
        data, validation_data = maybe_dataframe_to_xshards(
            data,
            validation_data,
            feature_cols,
            label_cols,
            mode="fit",
            num_workers=self.num_workers,
            accept_str_col=True)

        if isinstance(data, SparkXShards):
            # set train/validation data
            if validation_data is None:

                def transform_func(iter, init_param, param):
                    partition_data = list(iter)
                    param["data_creator"] = make_data_creator(partition_data)
                    return SparkRunner(**init_param).step(**param)

                res = data.rdd.repartition(self.num_workers).barrier() \
                    .mapPartitions(
                        lambda iter: transform_func(iter, init_params, params)).collect()
            else:

                def transform_func(iter, init_param, param):
                    data_tuple_list = list(iter)
                    data_list = [x[0] for x in data_tuple_list]
                    valid_list = [x[1] for x in data_tuple_list]
                    param["data_creator"] = make_data_creator(data_list)
                    param["validation_data_creator"] = make_data_creator(
                        valid_list)
                    return SparkRunner(**init_param).step(**param)

                res = data.zip(validation_data).rdd.repartition(self.num_workers).barrier() \
                    .mapPartitions(
                        lambda iter: transform_func(iter, init_params, params)).collect()
        else:
            params["data_creator"] = data
            params["validation_data_creator"] = validation_data

            def transform_func(iter, init_param, param):
                return SparkRunner(**init_param).step(**param)

            res = self.workerRDD.barrier().mapPartitions(
                lambda iter: transform_func(iter, init_params, params
                                            )).collect()

        if self.model_dir:
            # Create the temp dir before the try block so a failed mkdtemp does not
            # trigger a NameError in the finally clause.
            temp_dir = tempfile.mkdtemp()
            try:
                get_remote_file_to_local(os.path.join(self.model_dir,
                                                      "states.pkl"),
                                         os.path.join(temp_dir, "states.pkl"),
                                         over_write=True)
                import pickle
                with open(os.path.join(temp_dir, "states.pkl"), 'rb') as f:
                    states = pickle.load(f)
                    self.model_weights = states['weights']
                    self.epoch = states["epoch"]
            finally:
                shutil.rmtree(temp_dir)

        return res[0]
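A hedged usage sketch for this `fit` with a Spark DataFrame follows. The `Estimator.from_keras` import path and constructor arguments are assumptions about the Orca TF2 API and may differ by version; the DataFrame contents and column names are illustrative.

import tensorflow as tf
from pyspark.sql import SparkSession
from bigdl.orca.learn.tf2 import Estimator  # assumed import path for the TF2 Estimator

# In a real Orca job the Spark/Ray runtime is usually set up via init_orca_context first.
spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame([(1.0, 2.0, 0.0), (3.0, 4.0, 1.0)], ["f1", "f2", "label"])

def model_creator(config):
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
    model.compile(optimizer="adam", loss="mse")
    return model

est = Estimator.from_keras(model_creator=model_creator, workers_per_node=2)  # assumed signature
stats = est.fit(df, epochs=2, batch_size=4,
                feature_cols=["f1", "f2"], label_cols=["label"])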
Example 3
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            profile=False,
            reduce_results=True,
            info=None,
            feature_cols=None,
            label_cols=None):
        """
        Trains a PyTorch model given training data for several epochs.
        Under the hood, this calls `TrainingOperator.train_epoch()` on N parallel workers
        simultaneously.

        :param data: An instance of SparkXShards, a Spark DataFrame or a function that
               takes config and batch_size as argument and returns a PyTorch DataLoader for
               training.
        :param epochs: The number of epochs to train the model. Default is 1.
        :param batch_size: The number of samples per batch for each worker. Default is 32.
               The total batch size would be batch_size * workers_per_node * num_nodes.
               If your training data is a function, you can set batch_size to be the input
               batch_size of the function for the PyTorch DataLoader.
        :param profile: Boolean. Whether to return time stats for the training procedure.
               Default is False.
        :param reduce_results: Boolean. Whether to average all metrics across all workers into
               one dict. If a metric is a non-numerical value (or nested dictionaries), one value
               will be randomly selected among the workers. If False, returns a list of dicts for
               all workers.
               Default is True.
        :param info: An optional dictionary that can be passed to the TrainingOperator for
               train_epoch and train_batch.
        :param feature_cols: feature column names if data is Spark DataFrame.
        :param label_cols: label column names if data is Spark DataFrame.

        :return: A list of dictionaries of metrics, one for each training epoch. If reduce_results
                is False, this will return a nested list of metric dictionaries whose length will
                be equal to the total number of workers.
                You can also provide custom metrics by passing in a custom training_operator_cls
                when creating the Estimator.
        """
        from bigdl.orca.data import SparkXShards

        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             label_cols=label_cols,
                                             mode="fit",
                                             num_workers=self.num_workers)

        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols)
            from bigdl.orca.data.utils import process_spark_xshards
            ray_xshards = process_spark_xshards(data, self.num_workers)

            def transform_func(worker, partition_refs):
                data_creator = partition_refs_to_creator(partition_refs)
                # Should not wrap DistributedSampler on DataLoader for SparkXShards input.
                return worker.train_epochs.remote(data_creator, epochs,
                                                  batch_size, profile, info,
                                                  False)

            worker_stats = ray_xshards.reduce_partitions_for_actors(
                self.remote_workers, transform_func)
        else:
            assert isinstance(data, types.FunctionType), \
                "data should be either an instance of SparkXShards or a callable function, but " \
                "got type: {}".format(type(data))

            success, worker_stats = self._train_epochs(data,
                                                       epochs=epochs,
                                                       batch_size=batch_size,
                                                       profile=profile,
                                                       info=info)

        epoch_stats = list(map(list, zip(*worker_stats)))
        if reduce_results:
            epoch_stats = [self._process_stats(stats) for stats in epoch_stats]
        return epoch_stats
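The `zip(*worker_stats)` step above transposes per-worker results (one list of epoch metrics per worker) into per-epoch results (one list of worker metrics per epoch) before reduction. A small self-contained illustration, with made-up metric values:

# Two workers, each reporting metrics for two epochs (values are illustrative).
worker_stats = [
    [{"train_loss": 0.9}, {"train_loss": 0.5}],  # worker 0: epoch 1, epoch 2
    [{"train_loss": 1.1}, {"train_loss": 0.7}],  # worker 1: epoch 1, epoch 2
]
epoch_stats = list(map(list, zip(*worker_stats)))
# epoch_stats[0] == [{"train_loss": 0.9}, {"train_loss": 1.1}]   -> epoch 1, all workers
# epoch_stats[1] == [{"train_loss": 0.5}, {"train_loss": 0.7}]   -> epoch 2, all workers
# With reduce_results=True, each inner list is then averaged by self._process_stats.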
Example 4
    def evaluate(self,
                 data,
                 batch_size=32,
                 num_steps=None,
                 verbose=1,
                 sample_weight=None,
                 callbacks=None,
                 data_config=None,
                 feature_cols=None,
                 label_cols=None):
        """
        Evaluates the model on the validation data set.

        :param data: evaluate data. It can be XShards, Spark DataFrame or creator function which
               returns Iter or DataLoader.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
               numpy arrays.
        :param batch_size: Batch size used for evaluation. Default: 32.
        :param num_steps: Total number of steps (batches of samples) before declaring the
               evaluation round finished. Ignored with the default value of `None`.
        :param verbose: Prints output of one model if true.
        :param sample_weight: Optional Numpy array of weights for the samples, used for weighting
               the loss function. You can either pass a flat (1D) Numpy array with the same length
               as the input samples (1:1 mapping between weights and samples), or in the case of
               temporal data, you can pass a 2D array with shape (samples, sequence_length), to
               apply a different weight to every timestep of every sample.
        :param callbacks: List of Keras compatible callbacks to apply during evaluation.
        :param data_config: An optional dictionary that can be passed to data creator function.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data. Only used when data is a Spark DataFrame
               or an XShards of Pandas DataFrame. Default: None.
        :return: validation result
        """
        import numpy as np
        sc = OrcaContext.get_spark_context()
        logger.info("Starting validation step.")

        if self.model_weights:
            weights = sc.broadcast(self.model_weights)
        else:
            weights = None

        init_params = dict(model_creator=self.model_creator,
                           compile_args_creator=self.compile_args_creator,
                           config=self.config,
                           verbose=self.verbose,
                           size=self.num_workers,
                           model_weights=weights,
                           mode="evaluate",
                           cluster_info=self._get_cluster_info(sc))

        params = dict(
            batch_size=batch_size,
            verbose=verbose,
            sample_weight=sample_weight,
            steps=num_steps,
            callbacks=callbacks,
            data_config=data_config,
        )

        # Convert the DataFrame to XShards; ensure num_partitions >= num_workers.
        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             label_cols=label_cols,
                                             mode="evaluate",
                                             num_workers=self.num_workers,
                                             accept_str_col=True)

        if isinstance(data, SparkXShards):
            # set validation data
            def transform_func(iter, init_param, param):
                partition_data = list(iter)
                param["data_creator"] = make_data_creator(partition_data)
                return SparkRunner(**init_param).validate(**param)

            res = data.rdd.repartition(self.num_workers).barrier() \
                .mapPartitions(lambda iter: transform_func(iter, init_params, params)).collect()
        else:
            params["data_creator"] = data

            def transform_func(iter, init_param, param):
                return SparkRunner(**init_param).validate(**param)

            res = self.workerRDD.barrier().mapPartitions(
                lambda iter: transform_func(iter, init_params, params
                                            )).collect()

        return res[0]
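For the creator-function path of this `evaluate`, below is a minimal sketch of a TensorFlow data creator. The `(config, batch_size)` signature and the assumption that `data_config` entries arrive through `config` mirror the PyTorch creators documented above, but are assumptions rather than facts shown in this snippet.

import numpy as np
import tensorflow as tf

def eval_data_creator(config, batch_size):
    # Assumed: entries passed via data_config (e.g. "num_samples") show up in config.
    n = config.get("num_samples", 128)
    x = np.random.rand(n, 4).astype("float32")
    y = np.random.randint(0, 2, size=(n,)).astype("float32")
    return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)

# result = est.evaluate(eval_data_creator, batch_size=32, data_config={"num_samples": 256})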
Example 5
    def evaluate(self,
                 data,
                 batch_size=32,
                 num_steps=None,
                 profile=False,
                 info=None,
                 feature_cols=None,
                 label_cols=None):
        """
        Evaluates a PyTorch model given validation data.
        Note that only accuracy for classification with zero-based labels is supported by
        default. You can override validate_batch in TrainingOperator for other metrics.
        Under the hood, this calls `TrainingOperator.validate()` on N parallel workers
        simultaneously.

        :param data: An instance of SparkXShards, a Spark DataFrame or a function that
               takes config and batch_size as argument and returns a PyTorch DataLoader for
               validation.
        :param batch_size: The number of samples per batch for each worker. Default is 32.
               The total batch size would be batch_size * workers_per_node * num_nodes.
               If your validation data is a function, you can set batch_size to be the input
               batch_size of the function for the PyTorch DataLoader.
        :param num_steps: The number of batches to compute the validation results on. This
               corresponds to the number of times `TrainingOperator.validate_batch` is called.
        :param profile: Boolean. Whether to return time stats for the evaluation procedure.
               Default is False.
        :param info: An optional dictionary that can be passed to the TrainingOperator
               for validate.
        :param feature_cols: feature column names if data is a Spark DataFrame.
        :param label_cols: label column names if data is a Spark DataFrame.

        :return: A dictionary of metrics for the given data, including validation accuracy and loss.
                You can also provide custom metrics by passing in a custom training_operator_cls
                when creating the Estimator.
        """
        init_params = dict(
            mode="evaluate",
            state_dict=self.state_dict,
        )
        init_params.update(self.worker_init_params)

        params = dict(
            batch_size=batch_size,
            num_steps=num_steps,
            profile=profile,
            info=info,
        )

        from bigdl.orca.data import SparkXShards
        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             label_cols=label_cols,
                                             mode="evaluate",
                                             num_workers=self.num_workers)
        if isinstance(data, SparkXShards):
            # set validation data
            def transform_func(iter, init_param, param):
                partition_data = list(iter)
                param["data_creator"] = partition_to_creator(partition_data)
                return PytorchPysparkWorker(**init_param).validate(**param)

            res = data.rdd.repartition(self.num_workers).barrier() \
                .mapPartitions(lambda iter: transform_func(iter, init_params, params)).collect()
        else:
            params["data_creator"] = data

            def transform_func(iter, init_param, param):
                return PytorchPysparkWorker(**init_param).validate(**param)

            res = self.workerRDD.barrier().mapPartitions(
                lambda iter: transform_func(iter, init_params, params
                                            )).collect()

        return self._process_stats(res)
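The `data.rdd.repartition(...).barrier().mapPartitions(...)` call above launches one task per worker in Spark barrier mode and collects each task's result on the driver. A stripped-down, self-contained sketch of that mechanism, with a trivial stand-in for the worker:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
sc = spark.sparkContext

def run_on_partition(it):
    # In the estimator this is where PytorchPysparkWorker(**init_params).validate(**params)
    # would run; here we only count records so the sketch stays self-contained.
    yield sum(1 for _ in it)

# barrier() makes all partition tasks start together, as required for collective training.
results = sc.parallelize(range(100), 2).barrier().mapPartitions(run_on_partition).collect()
print(results)  # one entry per barrier task, e.g. [50, 50]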
Example 6
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            profile=False,
            reduce_results=True,
            info=None,
            feature_cols=None,
            label_cols=None):
        """
        Trains a PyTorch model given training data for several epochs.
        Under the hood, this calls `TrainingOperator.train_epoch()` on N parallel workers
        simultaneously.

        :param data: An instance of SparkXShards, a Spark DataFrame or a function that
               takes config and batch_size as argument and returns a PyTorch DataLoader for
               training.
        :param epochs: The number of epochs to train the model. Default is 1.
        :param batch_size: The number of samples per batch for each worker. Default is 32.
               The total batch size would be batch_size * workers_per_node * num_nodes.
               If your training data is a function, you can set batch_size to be the input
               batch_size of the function for the PyTorch DataLoader.
        :param profile: Boolean. Whether to return time stats for the training procedure.
               Default is False.
        :param reduce_results: Boolean. Whether to average all metrics across all workers into
               one dict. If a metric is a non-numerical value (or nested dictionaries), one value
               will be randomly selected among the workers. If False, returns a list of dicts for
               all workers.
               Default is True.
        :param info: An optional dictionary that can be passed to the TrainingOperator for
               train_epoch and train_batch.
        :param feature_cols: feature column names if data is Spark DataFrame.
        :param label_cols: label column names if data is Spark DataFrame.

        :return: A list of dictionaries of metrics, one for each training epoch. If reduce_results
                is False, this will return a nested list of metric dictionaries whose length will
                be equal to the total number of workers.
                You can also provide custom metrics by passing in a custom training_operator_cls
                when creating the Estimator.
        """
        init_params = dict(mode="fit", state_dict=self.state_dict)
        init_params.update(self.worker_init_params)

        params = dict(
            epochs=epochs,
            batch_size=batch_size,
            profile=profile,
            info=info,
        )

        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             label_cols=label_cols,
                                             mode="fit",
                                             num_workers=self.num_workers)

        if isinstance(data, SparkXShards):
            # set train data
            params["wrap_dataloader"] = False

            def transform_func(iter, init_params, param):
                partition_data = list(iter)
                param["data_creator"] = partition_to_creator(partition_data)
                runner = PytorchPysparkWorker(**init_params)
                result = runner.train_epochs(**param)
                runner.shutdown()
                return result

            res = data.rdd.repartition(self.num_workers).barrier() \
                .mapPartitions(
                lambda iter: transform_func(iter, init_params, params)).collect()

        else:
            assert isinstance(data, types.FunctionType), \
                "data should be either an instance of SparkXShards or a callable function, but " \
                "got type: {}".format(type(data))

            params["data_creator"] = data

            def transform_func(iter, init_param, param):
                return PytorchPysparkWorker(**init_param).train_epochs(**param)

            res = self.workerRDD.barrier().mapPartitions(
                lambda iter: transform_func(iter, init_params, params
                                            )).collect()

        self.state_dict = res[0][0]
        worker_stats = [re[1] for re in res]

        epoch_stats = list(map(list, zip(*worker_stats)))
        if reduce_results:
            epoch_stats = [self._process_stats(stats) for stats in epoch_stats]
        return epoch_stats
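As the docstring notes, `batch_size` is per worker, so the effective global batch size scales with the number of workers. A quick arithmetic sanity check with illustrative numbers:

batch_size_per_worker = 32            # the batch_size argument above
num_workers = 4                       # workers_per_node * num_nodes
dataset_size = 10_000                 # illustrative

global_batch_size = batch_size_per_worker * num_workers  # 128 samples per synchronized step
steps_per_epoch = dataset_size // global_batch_size      # 78 full steps per epoch
print(global_batch_size, steps_per_epoch)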
Example 7
    def evaluate(self,
                 data,
                 batch_size=32,
                 num_steps=None,
                 verbose=1,
                 sample_weight=None,
                 callbacks=None,
                 data_config=None,
                 feature_cols=None,
                 label_cols=None):
        """
        Evaluates the model on the validation data set.

        :param data: evaluate data. It can be XShards, Spark DataFrame or creator function which
               returns Iter or DataLoader.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
               numpy arrays.
        :param batch_size: Batch size used for evaluation. Default: 32.
        :param num_steps: Total number of steps (batches of samples) before declaring the evaluation
               round finished. Ignored with the default value of `None`.
        :param verbose: Prints output of one model if true.
        :param sample_weight: Optional Numpy array of weights for the samples, used for
               weighting the loss function. You can either pass a flat (1D) Numpy array with the
               same length as the input samples (1:1 mapping between weights and samples), or in
               the case of temporal data, you can pass a 2D array with shape (samples,
               sequence_length), to apply a different weight to every timestep of every sample.
        :param callbacks: List of Keras compatible callbacks to apply during evaluation.
        :param data_config: An optional dictionary that can be passed to data creator function.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data. Only used when data is a Spark DataFrame or
               an XShards of Pandas DataFrame.
               Default: None.
        :return: validation result
        """
        logger.info("Starting validation step.")
        params = dict(
            batch_size=batch_size,
            verbose=verbose,
            sample_weight=sample_weight,
            steps=num_steps,
            callbacks=callbacks,
            data_config=data_config,
        )
        from bigdl.orca.data import SparkXShards

        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             label_cols=label_cols,
                                             mode="evaluate",
                                             num_workers=self.num_workers,
                                             accept_str_col=True)

        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols)

            if data.num_partitions() != self.num_workers:
                data = data.repartition(self.num_workers)

            ray_xshards = RayXShards.from_spark_xshards(data)

            def transform_func(worker, partition_refs):
                params["data_creator"] = make_data_creator(partition_refs)
                return worker.validate.remote(**params)

            worker_stats = ray_xshards.reduce_partitions_for_actors(
                self.remote_workers, transform_func)
        else:  # data_creator functions; should return Iter or DataLoader
            params["data_creator"] = data
            params_list = [params] * self.num_workers

            worker_stats = ray.get([
                w.validate.remote(**params_list[i])
                for i, w in enumerate(self.remote_workers)
            ])
            worker_stats = list(itertools.chain.from_iterable(worker_stats))
        stats = worker_stats[0].copy()
        return stats
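The `else` branch above fans identical parameters out to every remote worker and gathers the results with `ray.get`. A minimal standalone sketch of that Ray actor fan-out pattern; the `Worker` actor and its `validate` method are stand-ins, not the estimator's worker class:

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class Worker:
    def validate(self, batch_size):
        # Stand-in for the real per-worker validation; returns a dummy metric dict.
        return {"batch_size": batch_size, "acc": 0.0}

workers = [Worker.remote() for _ in range(2)]
stats = ray.get([w.validate.remote(batch_size=32) for w in workers])  # one dict per worker
print(stats)
ray.shutdown()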
Example 8
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            verbose=1,
            callbacks=None,
            validation_data=None,
            class_weight=None,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            data_config=None,
            feature_cols=None,
            label_cols=None):
        """
        Train this TensorFlow model with train data.

        :param data: train data. It can be XShards, Spark DataFrame or creator function which
               returns Iter or DataLoader.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
               numpy arrays.
        :param epochs: Number of epochs to train the model. Default: 1.
        :param batch_size: Batch size used for training. Default: 32.
        :param verbose: Prints output of one model if true.
        :param callbacks: List of Keras compatible callbacks to apply during training.
        :param validation_data: validation data. Validation data type should be the same
               as train data.
        :param class_weight: Optional dictionary mapping class indices (integers) to a weight
               (float) value, used for weighting the loss function. This can be useful to tell
               the model to "pay more attention" to samples from an under-represented class.
        :param steps_per_epoch: Total number of steps (batches of samples) before declaring one
               epoch finished and starting the next epoch. If `steps_per_epoch` is `None`, the
               epoch will run until the input dataset is exhausted. When passing an infinitely
               repeating dataset, you must specify the `steps_per_epoch` argument.
        :param validation_steps: Total number of steps (batches of samples) to draw before stopping
               when performing validation at the end of every epoch. Default: None.
        :param validation_freq: Only relevant if validation data is provided. Integer or
               `collections_abc.Container` instance (e.g. list, tuple, etc.). If an integer,
               specifies how many training epochs to run before a new validation run is performed,
               e.g. `validation_freq=2` runs validation every 2 epochs. If a Container, specifies
               the epochs on which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
               validation at the end of the 1st, 2nd, and 10th epochs.
        :param data_config: An optional dictionary that can be passed to data creator function.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data. Only used when data is a Spark DataFrame or
               an XShards of Pandas DataFrame.
               Default: None.
        :return:
        """
        params = dict(epochs=epochs,
                      batch_size=batch_size,
                      verbose=verbose,
                      callbacks=callbacks,
                      class_weight=class_weight,
                      steps_per_epoch=steps_per_epoch,
                      validation_steps=validation_steps,
                      validation_freq=validation_freq,
                      data_config=data_config)

        from bigdl.orca.data import SparkXShards
        data, validation_data = maybe_dataframe_to_xshards(
            data,
            validation_data,
            feature_cols,
            label_cols,
            mode="fit",
            num_workers=self.num_workers,
            accept_str_col=True)

        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data, validation_data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols, validation_data, "fit")
            ray_xshards = process_spark_xshards(data, self.num_workers)

            if validation_data is None:

                def transform_func(worker, partition_refs):
                    params["data_creator"] = make_data_creator(partition_refs)
                    return worker.step.remote(**params)

                worker_stats = ray_xshards.reduce_partitions_for_actors(
                    self.remote_workers, transform_func)
            else:
                val_ray_xshards = process_spark_xshards(
                    validation_data, self.num_workers)

                def zip_func(worker, this_partition_refs, that_partition_refs):
                    params["data_creator"] = make_data_creator(
                        this_partition_refs)
                    params["validation_data_creator"] = \
                        make_data_creator(that_partition_refs)
                    return worker.step.remote(**params)

                worker_stats = ray_xshards.zip_reduce_shards_with_actors(
                    val_ray_xshards, self.remote_workers, zip_func)
        else:
            params["data_creator"] = data
            params["validation_data_creator"] = validation_data
            params_list = [params] * self.num_workers

            worker_stats = ray.get([
                self.remote_workers[i].step.remote(**params_list[i])
                for i in range(self.num_workers)
            ])
            worker_stats = list(itertools.chain.from_iterable(worker_stats))
        stats = worker_stats[0].copy()
        return stats
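As described in the docstring, when `data` is an XShards each partition holds a dict of numpy arrays. A minimal sketch of that per-partition payload; constructing the XShards itself is left out to avoid assuming construction APIs:

import numpy as np

# The {'x': ..., 'y': ...} layout each XShards partition is expected to hold;
# x and y may also be tuples of arrays for multi-input / multi-output models.
partition = {
    "x": np.random.rand(64, 10).astype("float32"),               # features
    "y": np.random.randint(0, 2, size=(64, 1)).astype("float32"),  # labels
}
print(partition["x"].shape, partition["y"].shape)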