Beispiel #1
0
    def train(self,
              data,
              epochs=1,
              batch_size=32,
              profile=False,
              reduce_results=True,
              info=None,
              feature_cols=None,
              label_cols=None):
        """
        Launch training on all remote workers and gather their statistics.

        See the documentation in
        'zoo.orca.learn.pytorch.estimator.PyTorchRayEstimatorWrapper.fit'.

        :param data: Training input. It is first passed through
               `maybe_dataframe_to_xshards`; the result must be either a SparkXShards
               or a plain function (data creator).
        :param epochs: Forwarded to the workers' training call. Default: 1.
        :param batch_size: Forwarded to the workers' training call. Default: 32.
        :param profile: Forwarded to the workers' training call. Default: False.
        :param reduce_results: If true, every group of transposed worker stats is
               passed through `self._process_stats`; otherwise the raw groups are
               returned. Default: True.
        :param info: Forwarded to the workers' training call. Default: None.
        :param feature_cols: Feature column name(s), used for the DataFrame conversion
               and for XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s), used for the DataFrame conversion
               and for XShards of Pandas DataFrame. Default: None.
        :return: A list obtained by transposing the per-worker stats (presumably one
               entry per epoch), reduced with `self._process_stats` when
               `reduce_results` is true.
        """
        from zoo.orca.data import SparkXShards

        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             label_cols=label_cols,
                                             mode="fit")

        if not isinstance(data, SparkXShards):
            # Anything that is not SparkXShards has to be a data creator function.
            assert isinstance(data, types.FunctionType), \
                "data should be either an instance of SparkXShards or a callable function, but " \
                "got type: {}".format(type(data))

            # The first element of the returned pair is not needed here.
            _, worker_stats = self._train_epochs(data,
                                                 epochs=epochs,
                                                 batch_size=batch_size,
                                                 profile=profile,
                                                 info=info)
        else:
            holds_pandas = data._get_class_name() == 'pandas.core.frame.DataFrame'
            if holds_pandas:
                data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols)
            from zoo.orca.data.utils import process_spark_xshards
            ray_xshards = process_spark_xshards(data, self.num_workers)

            def launch_on_worker(worker, partition_refs):
                creator = partition_refs_to_creator(partition_refs)
                # The trailing False: do not wrap a DistributedSampler around the
                # DataLoader for SparkXShards input.
                return worker.train_epochs.remote(creator, epochs, batch_size,
                                                  profile, info, False)

            worker_stats = ray_xshards.reduce_partitions_for_actors(
                self.remote_workers, launch_on_worker)

        # Transpose: group the i-th result of every worker together.
        grouped = [list(group) for group in zip(*worker_stats)]
        if not reduce_results:
            return grouped
        return [self._process_stats(group) for group in grouped]
Beispiel #2
0
    def validate(self,
                 data,
                 batch_size=32,
                 num_steps=None,
                 profile=False,
                 info=None,
                 feature_cols=None,
                 label_cols=None):
        """
        Launch validation on all remote workers and reduce their statistics.

        See the documentation in
        'zoo.orca.learn.pytorch.estimator.PyTorchRayEstimatorWrapper.evaluate'.

        :param data: Validation input. It is first passed through
               `maybe_dataframe_to_xshards`; the result must be either a SparkXShards
               or a plain function (data creator).
        :param batch_size: Forwarded to the workers' `validate` call. Default: 32.
        :param num_steps: Forwarded to the workers' `validate` call. Default: None.
        :param profile: Forwarded to the workers' `validate` call. Default: False.
        :param info: Forwarded to the workers' `validate` call. Default: None.
        :param feature_cols: Feature column name(s), used for the DataFrame conversion
               and for XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s), used for the DataFrame conversion
               and for XShards of Pandas DataFrame. Default: None.
        :return: The result of `self._process_stats` applied to the worker stats.
        """
        from zoo.orca.data import SparkXShards
        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             label_cols=label_cols,
                                             mode="evaluate")
        if not isinstance(data, SparkXShards):
            # Anything that is not SparkXShards has to be a data creator function.
            assert isinstance(data, types.FunctionType), \
                "data should be either an instance of SparkXShards or a callable function, but " \
                "got type: {}".format(type(data))

            pending = [
                worker.validate.remote(data_creator=data,
                                       batch_size=batch_size,
                                       num_steps=num_steps,
                                       profile=profile,
                                       info=info)
                for worker in self.remote_workers
            ]
            worker_stats = ray.get(pending)
        else:
            holds_pandas = data._get_class_name() == 'pandas.core.frame.DataFrame'
            if holds_pandas:
                data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols)
            from zoo.orca.data.utils import process_spark_xshards
            ray_xshards = process_spark_xshards(data, self.num_workers)

            def launch_on_worker(worker, partition_refs):
                creator = partition_refs_to_creator(partition_refs)
                # The trailing False: do not wrap a DistributedSampler around the
                # DataLoader for SparkXShards input.
                return worker.validate.remote(creator, batch_size, num_steps,
                                              profile, info, False)

            worker_stats = ray_xshards.reduce_partitions_for_actors(
                self.remote_workers, launch_on_worker)

        return self._process_stats(worker_stats)
    def train(self,
              data,
              epochs=1,
              batch_size=32,
              profile=False,
              reduce_results=True,
              info=None):
        """
        Launch training on all remote workers and gather their statistics.

        See the documentation in
        'zoo.orca.learn.pytorch.estimator.PyTorchRayEstimatorWrapper.fit'.

        :param data: Either a SparkXShards or a plain function (data creator).
        :param epochs: Forwarded to the workers' training call. Default: 1.
        :param batch_size: Forwarded to the workers' training call. Default: 32.
        :param profile: Forwarded to the workers' training call. Default: False.
        :param reduce_results: If true, every group of transposed worker stats is
               passed through `self._process_stats`; otherwise the raw groups are
               returned. Default: True.
        :param info: Forwarded to the workers' training call. Default: None.
        :return: A list obtained by transposing the per-worker stats (presumably one
               entry per epoch), reduced with `self._process_stats` when
               `reduce_results` is true.
        """
        from zoo.orca.data import SparkXShards
        if not isinstance(data, SparkXShards):
            # Anything that is not SparkXShards has to be a data creator function.
            assert isinstance(data, types.FunctionType), \
                "data should be either an instance of SparkXShards or a callable function, but " \
                "got type: {}".format(type(data))

            # The first element of the returned pair is not needed here.
            _, worker_stats = self._train_epochs(data,
                                                 epochs=epochs,
                                                 batch_size=batch_size,
                                                 profile=profile,
                                                 info=info)
        else:
            from zoo.orca.data.utils import process_spark_xshards
            ray_xshards = process_spark_xshards(data, self.num_workers)

            def launch_on_worker(worker, shards_ref):
                creator = shards_ref_to_creator(shards_ref)
                # The trailing False: do not wrap a DistributedSampler around the
                # DataLoader for SparkXShards input.
                return worker.train_epochs.remote(creator, epochs, batch_size,
                                                  profile, info, False)

            worker_stats = ray_xshards.transform_shards_with_actors(
                self.remote_workers, launch_on_worker,
                gang_scheduling=True).collect_partitions()

        # Transpose: group the i-th result of every worker together.
        grouped = [list(group) for group in zip(*worker_stats)]
        if not reduce_results:
            return grouped
        return [self._process_stats(group) for group in grouped]
    def validate(self,
                 data,
                 batch_size=32,
                 num_steps=None,
                 profile=False,
                 info=None):
        """
        Launch validation on all remote workers and reduce their statistics.

        See the documentation in
        'zoo.orca.learn.pytorch.estimator.PyTorchRayEstimatorWrapper.evaluate'.

        :param data: Either a SparkXShards or a plain function (data creator).
        :param batch_size: Forwarded to the workers' `validate` call. Default: 32.
        :param num_steps: Forwarded to the workers' `validate` call. Default: None.
        :param profile: Forwarded to the workers' `validate` call. Default: False.
        :param info: Forwarded to the workers' `validate` call. Default: None.
        :return: The result of `self._process_stats` applied to the worker stats.
        """
        from zoo.orca.data import SparkXShards
        if not isinstance(data, SparkXShards):
            # Anything that is not SparkXShards has to be a data creator function.
            assert isinstance(data, types.FunctionType), \
                "data should be either an instance of SparkXShards or a callable function, but " \
                "got type: {}".format(type(data))

            pending = [
                worker.validate.remote(data_creator=data,
                                       batch_size=batch_size,
                                       num_steps=num_steps,
                                       profile=profile,
                                       info=info)
                for worker in self.remote_workers
            ]
            worker_stats = ray.get(pending)
        else:
            from zoo.orca.data.utils import process_spark_xshards
            ray_xshards = process_spark_xshards(data, self.num_workers)

            def launch_on_worker(worker, shards_ref):
                creator = shards_ref_to_creator(shards_ref)
                # The trailing False: do not wrap a DistributedSampler around the
                # DataLoader for SparkXShards input.
                return worker.validate.remote(creator, batch_size, num_steps,
                                              profile, info, False)

            worker_stats = ray_xshards.transform_shards_with_actors(
                self.remote_workers, launch_on_worker,
                gang_scheduling=True).collect_partitions()

        return self._process_stats(worker_stats)
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            validation_data=None,
            train_resize_batch_num=None):
        """
        Train an MXNet model on the given data (optionally with validation data)
        for a number of epochs.

        :param data: Either an instance of SparkXShards, or a function taking config and kv
        as arguments and returning an MXNet DataIter/DataLoader for training.
        Data related settings for such a function can be supplied through its config argument.
        kv is an instance of the MXNet distributed key-value store; kv.num_workers and kv.rank
        may be used inside the function to split data among workers if necessary.

        :param epochs: How many epochs to train the MXNet model for. Default is 1.

        :param batch_size: How many samples each worker puts in a batch. Default is 32.

        :param validation_data: Either an instance of SparkXShards, or a function taking config
        and kv as arguments and returning an MXNet DataIter/DataLoader for validation.
        Data related settings for such a function can be supplied through its config argument.
        kv is an instance of the MXNet distributed key-value store; kv.num_workers and kv.rank
        may be used inside the function to split data among workers if necessary.

        :param train_resize_batch_num: Number of batches each epoch is resized to.
        Default is None. Specifying it may be necessary when the amount of training data
        differs between workers: MXNet distributed training would crash as soon as the first
        worker finishes if the workers hold unbalanced training data.
        More details in this issue: https://github.com/apache/incubator-mxnet/issues/17651

        :return: The stats gathered from the workers and servers, as a single list.
        """
        if validation_data:
            assert self.validation_metrics_creator,\
                "Metrics not defined for validation, please specify validation_metrics_creator " \
                "when creating the Estimator"
        from zoo.orca.data import SparkXShards
        if not isinstance(data, SparkXShards):
            # data_creator functions; should return Iter or DataLoader
            assert isinstance(data, types.FunctionType),\
                "train_data should be either an instance of SparkXShards or a callable function"
            train_creators = [data] * self.num_workers
            if validation_data:
                assert isinstance(validation_data, types.FunctionType),\
                    "val_data should be either an instance of SparkXShards or a callable function"
            val_creators = [validation_data] * self.num_workers
            self.runners = self.workers + self.servers
            # Servers do not consume data, so they simply receive None.
            train_creators += [None] * self.num_servers
            val_creators += [None] * self.num_servers

            pending = [
                runner.train.remote(train_creators[idx], epochs, batch_size,
                                    val_creators[idx], train_resize_batch_num)
                for idx, runner in enumerate(self.runners)
            ]
            return list(itertools.chain.from_iterable(ray.get(pending)))

        ray_xshards = process_spark_xshards(data, self.num_workers)

        if validation_data is not None:
            val_ray_xshards = process_spark_xshards(validation_data,
                                                    self.num_workers)

            def train_with_validation(worker, this_shards_ref, that_shards_ref):
                train_creator = shards_ref_to_creator(this_shards_ref,
                                                      shuffle=True)
                val_creator = shards_ref_to_creator(that_shards_ref,
                                                    shuffle=True)
                return worker.train.remote(train_creator, epochs, batch_size,
                                           val_creator, train_resize_batch_num)

            stats_shards = ray_xshards.zip_shards_with_actors(
                val_ray_xshards,
                self.workers,
                train_with_validation,
                gang_scheduling=True)
        else:

            def train_only(worker, shards_ref):
                train_creator = shards_ref_to_creator(shards_ref, shuffle=True)
                return worker.train.remote(train_creator, epochs, batch_size,
                                           None, train_resize_batch_num)

            stats_shards = ray_xshards.transform_shards_with_actors(
                self.workers, train_only, gang_scheduling=True)

        # Start the servers before collecting the worker results.
        server_futures = [
            server.train.remote(None, epochs, batch_size, None,
                                train_resize_batch_num)
            for server in self.servers
        ]
        worker_stats = stats_shards.collect()
        server_stats = list(
            itertools.chain.from_iterable(ray.get(server_futures)))
        return worker_stats + server_stats
Beispiel #6
0
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            verbose=1,
            callbacks=None,
            validation_data=None,
            class_weight=None,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            data_config=None,
            feature_cols=None,
            label_cols=None):
        """
        Train this tensorflow model with train data.

        :param data: train data. It can be XShards, Spark DataFrame or creator function which
               returns Iter or DataLoader.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
               numpy arrays.
        :param epochs: Number of epochs to train the model. Default: 1.
        :param batch_size: Batch size used for training. Default: 32.
        :param verbose: Prints output of one model if true.
        :param callbacks: List of Keras compatible callbacks to apply during training.
        :param validation_data: validation data. Its type should be the same as the type of
               the train data.
        :param class_weight: Optional dictionary mapping class indices (integers) to a weight
               (float) value, used for weighting the loss function. Useful for telling the
               model to "pay more attention" to samples from an under-represented class.
        :param steps_per_epoch: Total number of steps (batches of samples) before one epoch is
               declared finished and the next one starts. If `steps_per_epoch` is `None`, the
               epoch runs until the input dataset is exhausted. When passing an infinitely
               repeating dataset, the `steps_per_epoch` argument must be specified.
        :param validation_steps: Total number of steps (batches of samples) to draw before
               stopping when performing validation at the end of every epoch. Default: None.
        :param validation_freq: Only relevant if validation data is provided. Integer or
               `collections_abc.Container` instance (e.g. list, tuple, etc.). If an integer,
               specifies how many training epochs to run before a new validation run is
               performed, e.g. `validation_freq=2` runs validation every 2 epochs. If a
               Container, specifies the epochs on which to run validation, e.g.
               `validation_freq=[1, 2, 10]` runs validation at the end of the 1st, 2nd, and
               10th epochs.
        :param data_config: An optional dictionary that can be passed to data creator function.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :return: A copy of the first element of the collected worker stats.
        """
        # Keyword arguments shared by every worker's `step` call.
        params = {
            "epochs": epochs,
            "batch_size": batch_size,
            "verbose": verbose,
            "callbacks": callbacks,
            "class_weight": class_weight,
            "steps_per_epoch": steps_per_epoch,
            "validation_steps": validation_steps,
            "validation_freq": validation_freq,
            "data_config": data_config,
        }

        from zoo.orca.data import SparkXShards
        data, validation_data = maybe_dataframe_to_xshards(
            data,
            validation_data,
            feature_cols,
            label_cols,
            mode="fit",
            num_workers=self.num_workers,
            accept_str_col=True)

        if not isinstance(data, SparkXShards):
            # Creator functions: every worker receives the same keyword arguments.
            step_kwargs = dict(params,
                               data_creator=data,
                               validation_data_creator=validation_data)
            pending = [
                self.remote_workers[idx].step.remote(**step_kwargs)
                for idx in range(self.num_workers)
            ]
            worker_stats = list(
                itertools.chain.from_iterable(ray.get(pending)))
            return worker_stats[0].copy()

        holds_pandas = data._get_class_name() == 'pandas.core.frame.DataFrame'
        if holds_pandas:
            data, validation_data = process_xshards_of_pandas_dataframe(
                data, feature_cols, label_cols, validation_data, "fit")
        ray_xshards = process_spark_xshards(data, self.num_workers)

        if validation_data is not None:
            val_ray_xshards = process_spark_xshards(validation_data,
                                                    self.num_workers)

            def step_with_validation(worker, this_partition_refs,
                                     that_partition_refs):
                step_kwargs = dict(
                    params,
                    data_creator=make_data_creator(this_partition_refs),
                    validation_data_creator=make_data_creator(
                        that_partition_refs))
                return worker.step.remote(**step_kwargs)

            worker_stats = ray_xshards.zip_reduce_shards_with_actors(
                val_ray_xshards, self.remote_workers, step_with_validation)
        else:

            def step_only(worker, partition_refs):
                step_kwargs = dict(
                    params, data_creator=make_data_creator(partition_refs))
                return worker.step.remote(**step_kwargs)

            worker_stats = ray_xshards.reduce_partitions_for_actors(
                self.remote_workers, step_only)

        return worker_stats[0].copy()