def train(self,
          data,
          epochs=1,
          batch_size=32,
          profile=False,
          reduce_results=True,
          info=None,
          feature_cols=None,
          label_cols=None):
    """
    See the documentation in
    'zoo.orca.learn.pytorch.estimator.PyTorchRayEstimatorWrapper.fit'.
    """
    from zoo.orca.data import SparkXShards
    data, _ = maybe_dataframe_to_xshards(data,
                                         validation_data=None,
                                         feature_cols=feature_cols,
                                         label_cols=label_cols,
                                         mode="fit")
    if isinstance(data, SparkXShards):
        if data._get_class_name() == 'pandas.core.frame.DataFrame':
            data = process_xshards_of_pandas_dataframe(data, feature_cols, label_cols)
        from zoo.orca.data.utils import process_spark_xshards
        ray_xshards = process_spark_xshards(data, self.num_workers)

        def transform_func(worker, partition_refs):
            data_creator = partition_refs_to_creator(partition_refs)
            # Should not wrap DistributedSampler on DataLoader for SparkXShards input.
            return worker.train_epochs.remote(data_creator, epochs, batch_size,
                                              profile, info, False)

        worker_stats = ray_xshards.reduce_partitions_for_actors(
            self.remote_workers, transform_func)
    else:
        assert isinstance(data, types.FunctionType), \
            "data should be either an instance of SparkXShards or a callable function, " \
            "but got type: {}".format(type(data))
        success, worker_stats = self._train_epochs(data,
                                                   epochs=epochs,
                                                   batch_size=batch_size,
                                                   profile=profile,
                                                   info=info)

    epoch_stats = list(map(list, zip(*worker_stats)))
    if reduce_results:
        for i in range(len(epoch_stats)):
            epoch_stats[i] = self._process_stats(epoch_stats[i])
    return epoch_stats
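# Illustrative usage sketch (not from the source): calling train() on an
# XShards of Pandas DataFrames read with zoo.orca.data.pandas. The file path
# and column names are hypothetical placeholders.
#
#   import zoo.orca.data.pandas
#
#   shards = zoo.orca.data.pandas.read_csv("hdfs://path/to/train_data")
#   stats = estimator.train(shards, epochs=2, batch_size=64,
#                           feature_cols=["f1", "f2"], label_cols=["label"])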
def validate(self,
             data,
             batch_size=32,
             num_steps=None,
             profile=False,
             info=None,
             feature_cols=None,
             label_cols=None):
    """
    See the documentation in
    'zoo.orca.learn.pytorch.estimator.PyTorchRayEstimatorWrapper.evaluate'.
    """
    from zoo.orca.data import SparkXShards
    data, _ = maybe_dataframe_to_xshards(data,
                                         validation_data=None,
                                         feature_cols=feature_cols,
                                         label_cols=label_cols,
                                         mode="evaluate")
    if isinstance(data, SparkXShards):
        if data._get_class_name() == 'pandas.core.frame.DataFrame':
            data = process_xshards_of_pandas_dataframe(data, feature_cols, label_cols)
        from zoo.orca.data.utils import process_spark_xshards
        ray_xshards = process_spark_xshards(data, self.num_workers)

        def transform_func(worker, partition_refs):
            data_creator = partition_refs_to_creator(partition_refs)
            # Should not wrap DistributedSampler on DataLoader for SparkXShards input.
            return worker.validate.remote(data_creator, batch_size, num_steps,
                                          profile, info, False)

        worker_stats = ray_xshards.reduce_partitions_for_actors(
            self.remote_workers, transform_func)
    else:
        assert isinstance(data, types.FunctionType), \
            "data should be either an instance of SparkXShards or a callable function, " \
            "but got type: {}".format(type(data))
        params = dict(data_creator=data,
                      batch_size=batch_size,
                      num_steps=num_steps,
                      profile=profile,
                      info=info)
        worker_stats = ray.get([w.validate.remote(**params)
                                for w in self.remote_workers])
    return self._process_stats(worker_stats)
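# Illustrative usage sketch (not from the source): calling validate() with a
# data creator function. The creator signature (config, batch_size) and the
# TensorDataset contents below are assumptions for illustration.
#
#   import torch
#   from torch.utils.data import DataLoader, TensorDataset
#
#   def val_data_creator(config, batch_size):
#       x = torch.randn(256, 10)
#       y = torch.randint(0, 2, (256, 1)).float()
#       return DataLoader(TensorDataset(x, y), batch_size=batch_size)
#
#   metrics = estimator.validate(val_data_creator, batch_size=64)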
def train(self,
          data,
          epochs=1,
          batch_size=32,
          profile=False,
          reduce_results=True,
          info=None):
    """
    See the documentation in
    'zoo.orca.learn.pytorch.estimator.PyTorchRayEstimatorWrapper.fit'.
    """
    from zoo.orca.data import SparkXShards
    if isinstance(data, SparkXShards):
        from zoo.orca.data.utils import process_spark_xshards
        ray_xshards = process_spark_xshards(data, self.num_workers)

        def transform_func(worker, shards_ref):
            data_creator = shards_ref_to_creator(shards_ref)
            # Should not wrap DistributedSampler on DataLoader for SparkXShards input.
            return worker.train_epochs.remote(data_creator, epochs, batch_size,
                                              profile, info, False)

        stats_shards = ray_xshards.transform_shards_with_actors(
            self.remote_workers, transform_func, gang_scheduling=True)
        worker_stats = stats_shards.collect_partitions()
    else:
        assert isinstance(data, types.FunctionType), \
            "data should be either an instance of SparkXShards or a callable function, " \
            "but got type: {}".format(type(data))
        success, worker_stats = self._train_epochs(data,
                                                   epochs=epochs,
                                                   batch_size=batch_size,
                                                   profile=profile,
                                                   info=info)

    epoch_stats = list(map(list, zip(*worker_stats)))
    if reduce_results:
        for i in range(len(epoch_stats)):
            epoch_stats[i] = self._process_stats(epoch_stats[i])
    return epoch_stats
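# Illustrative usage sketch (not from the source): training on an XShards
# built from an in-memory {'x': ..., 'y': ...} numpy dictionary via
# XShards.partition. Array shapes are hypothetical.
#
#   import numpy as np
#   from zoo.orca.data import XShards
#
#   shards = XShards.partition({
#       "x": np.random.randn(1024, 10).astype(np.float32),
#       "y": np.random.randint(0, 2, (1024, 1)).astype(np.float32)})
#   stats = estimator.train(shards, epochs=2, batch_size=64)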
def validate(self,
             data,
             batch_size=32,
             num_steps=None,
             profile=False,
             info=None):
    """
    See the documentation in
    'zoo.orca.learn.pytorch.estimator.PyTorchRayEstimatorWrapper.evaluate'.
    """
    from zoo.orca.data import SparkXShards
    if isinstance(data, SparkXShards):
        from zoo.orca.data.utils import process_spark_xshards
        ray_xshards = process_spark_xshards(data, self.num_workers)

        def transform_func(worker, shards_ref):
            data_creator = shards_ref_to_creator(shards_ref)
            # Should not wrap DistributedSampler on DataLoader for SparkXShards input.
            return worker.validate.remote(data_creator, batch_size, num_steps,
                                          profile, info, False)

        stats_shards = ray_xshards.transform_shards_with_actors(
            self.remote_workers, transform_func, gang_scheduling=True)
        worker_stats = stats_shards.collect_partitions()
    else:
        assert isinstance(data, types.FunctionType), \
            "data should be either an instance of SparkXShards or a callable function, " \
            "but got type: {}".format(type(data))
        params = dict(data_creator=data,
                      batch_size=batch_size,
                      num_steps=num_steps,
                      profile=profile,
                      info=info)
        worker_stats = ray.get([w.validate.remote(**params)
                                for w in self.remote_workers])
    return self._process_stats(worker_stats)
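# Illustrative usage sketch (not from the source): evaluating the same XShards
# used for training, capping evaluation at a fixed number of batches; profile
# is assumed to add timing statistics to the result. The num_steps value is
# hypothetical.
#
#   metrics = estimator.validate(shards, batch_size=64, num_steps=100,
#                                profile=True)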
def fit(self,
        data,
        epochs=1,
        batch_size=32,
        validation_data=None,
        train_resize_batch_num=None):
    """
    Trains an MXNet model given train data (with optional validation data) for
    several epochs.

    :param data: An instance of SparkXShards or a function that takes config and kv
           as arguments and returns an MXNet DataIter/DataLoader for training. You can
           specify data-related configurations for this function in the config argument
           above. kv is an instance of the MXNet distributed key-value store.
           kv.num_workers and kv.rank can be used in this function to split data for
           different workers if necessary.
    :param epochs: The number of epochs to train the MXNet model. Default is 1.
    :param batch_size: The number of samples per batch for each worker. Default is 32.
    :param validation_data: An instance of SparkXShards or a function that takes config
           and kv as arguments and returns an MXNet DataIter/DataLoader for validation.
           You can specify data-related configurations for this function in the config
           argument above. kv is an instance of the MXNet distributed key-value store.
           kv.num_workers and kv.rank can be used in this function to split data for
           different workers if necessary.
    :param train_resize_batch_num: The number of batches per epoch to resize to.
           Default is None. You might need to specify this if the size of train data
           for each worker varies. MXNet distributed training would crash when the
           first worker finishes the training if the workers have unbalanced training
           data. See this issue for more details:
           https://github.com/apache/incubator-mxnet/issues/17651
    """
    if validation_data:
        assert self.validation_metrics_creator, \
            "Metrics not defined for validation, please specify " \
            "validation_metrics_creator when creating the Estimator"
    from zoo.orca.data import SparkXShards
    if isinstance(data, SparkXShards):
        ray_xshards = process_spark_xshards(data, self.num_workers)
        if validation_data is None:
            def transform_func(worker, shards_ref):
                data_creator = shards_ref_to_creator(shards_ref, shuffle=True)
                return worker.train.remote(data_creator, epochs, batch_size,
                                           None, train_resize_batch_num)

            stats_shards = ray_xshards.transform_shards_with_actors(
                self.workers, transform_func, gang_scheduling=True)
        else:
            val_ray_xshards = process_spark_xshards(validation_data, self.num_workers)

            def zip_func(worker, this_shards_ref, that_shards_ref):
                data_creator = shards_ref_to_creator(this_shards_ref, shuffle=True)
                validation_data_creator = shards_ref_to_creator(that_shards_ref,
                                                                shuffle=True)
                return worker.train.remote(data_creator, epochs, batch_size,
                                           validation_data_creator,
                                           train_resize_batch_num)

            stats_shards = ray_xshards.zip_shards_with_actors(
                val_ray_xshards, self.workers, zip_func, gang_scheduling=True)
        server_stats = [server.train.remote(None, epochs, batch_size, None,
                                            train_resize_batch_num)
                        for server in self.servers]
        worker_stats = stats_shards.collect()
        server_stats = ray.get(server_stats)
        server_stats = list(itertools.chain.from_iterable(server_stats))
        stats = worker_stats + server_stats
    else:  # data creator functions; should return an Iter or DataLoader
        assert isinstance(data, types.FunctionType), \
            "train_data should be either an instance of SparkXShards or a callable function"
        train_data_list = [data] * self.num_workers
        if validation_data:
            assert isinstance(validation_data, types.FunctionType), \
                "val_data should be either an instance of SparkXShards or a callable function"
            val_data_list = [validation_data] * self.num_workers
        else:
            # Without validation data, pass None creators to the workers.
            val_data_list = [None] * self.num_workers
        self.runners = self.workers + self.servers
        # For servers, data is not used, so just pass a None value.
        train_data_list += [None] * self.num_servers
        val_data_list += [None] * self.num_servers
        stats = ray.get([
            runner.train.remote(train_data_list[i], epochs, batch_size,
                                val_data_list[i], train_resize_batch_num)
            for i, runner in enumerate(self.runners)
        ])
        stats = list(itertools.chain.from_iterable(stats))
    return stats
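# Illustrative usage sketch (not from the source): a train data creator that
# uses the MXNet kv store to shard data across workers, then fit() with a
# resize count. Shapes, config keys and the resize value are hypothetical.
#
#   import numpy as np
#   import mxnet as mx
#
#   def train_data_creator(config, kv):
#       x = np.random.randn(1000, 10).astype(np.float32)
#       y = np.random.randint(0, 2, 1000).astype(np.float32)
#       # Each worker takes its own contiguous slice, keyed by kv.rank.
#       n = len(x) // kv.num_workers
#       part = slice(kv.rank * n, (kv.rank + 1) * n)
#       return mx.io.NDArrayIter(x[part], y[part],
#                                batch_size=config["batch_size"])
#
#   stats = estimator.fit(train_data_creator, epochs=2, batch_size=32,
#                         train_resize_batch_num=10)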
def fit(self,
        data,
        epochs=1,
        batch_size=32,
        verbose=1,
        callbacks=None,
        validation_data=None,
        class_weight=None,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        data_config=None,
        feature_cols=None,
        label_cols=None):
    """
    Train this TensorFlow model with train data.

    :param data: train data. It can be XShards, Spark DataFrame or a creator function
           which returns an Iter or DataLoader. If data is XShards, each partition can
           be a Pandas DataFrame or a dictionary of {'x': feature, 'y': label}, where
           feature(label) is a numpy array or a tuple of numpy arrays.
    :param epochs: Number of epochs to train the model. Default: 1.
    :param batch_size: Batch size used for training. Default: 32.
    :param verbose: Prints output of one model if true.
    :param callbacks: List of Keras-compatible callbacks to apply during training.
    :param validation_data: validation data. Validation data type should be the same
           as train data.
    :param class_weight: Optional dictionary mapping class indices (integers) to a
           weight (float) value, used for weighting the loss function. This can be
           useful to tell the model to "pay more attention" to samples from an
           under-represented class.
    :param steps_per_epoch: Total number of steps (batches of samples) before declaring
           one epoch finished and starting the next epoch. If `steps_per_epoch` is
           `None`, the epoch will run until the input dataset is exhausted. When
           passing an infinitely repeating dataset, you must specify the
           `steps_per_epoch` argument.
    :param validation_steps: Total number of steps (batches of samples) to draw before
           stopping when performing validation at the end of every epoch. Default: None.
    :param validation_freq: Only relevant if validation data is provided. Integer or
           `collections_abc.Container` instance (e.g. list, tuple, etc.). If an
           integer, specifies how many training epochs to run before a new validation
           run is performed, e.g. `validation_freq=2` runs validation every 2 epochs.
           If a Container, specifies the epochs on which to run validation, e.g.
           `validation_freq=[1, 2, 10]` runs validation at the end of the 1st, 2nd,
           and 10th epochs.
    :param data_config: An optional dictionary that can be passed to the data creator
           function.
    :param feature_cols: Feature column name(s) of data. Only used when data is a
           Spark DataFrame or an XShards of Pandas DataFrame. Default: None.
    :param label_cols: Label column name(s) of data. Only used when data is a Spark
           DataFrame or an XShards of Pandas DataFrame. Default: None.
    :return:
    """
    params = dict(epochs=epochs,
                  batch_size=batch_size,
                  verbose=verbose,
                  callbacks=callbacks,
                  class_weight=class_weight,
                  steps_per_epoch=steps_per_epoch,
                  validation_steps=validation_steps,
                  validation_freq=validation_freq,
                  data_config=data_config)
    from zoo.orca.data import SparkXShards
    data, validation_data = maybe_dataframe_to_xshards(
        data, validation_data, feature_cols, label_cols,
        mode="fit", num_workers=self.num_workers, accept_str_col=True)
    if isinstance(data, SparkXShards):
        if data._get_class_name() == 'pandas.core.frame.DataFrame':
            data, validation_data = process_xshards_of_pandas_dataframe(
                data, feature_cols, label_cols, validation_data, "fit")
        ray_xshards = process_spark_xshards(data, self.num_workers)
        if validation_data is None:
            def transform_func(worker, partition_refs):
                params["data_creator"] = make_data_creator(partition_refs)
                return worker.step.remote(**params)

            worker_stats = ray_xshards.reduce_partitions_for_actors(
                self.remote_workers, transform_func)
        else:
            val_ray_xshards = process_spark_xshards(validation_data, self.num_workers)

            def zip_func(worker, this_partition_refs, that_partition_refs):
                params["data_creator"] = make_data_creator(this_partition_refs)
                params["validation_data_creator"] = \
                    make_data_creator(that_partition_refs)
                return worker.step.remote(**params)

            worker_stats = ray_xshards.zip_reduce_shards_with_actors(
                val_ray_xshards, self.remote_workers, zip_func)
    else:
        params["data_creator"] = data
        params["validation_data_creator"] = validation_data
        params_list = [params] * self.num_workers
        worker_stats = ray.get([self.remote_workers[i].step.remote(**params_list[i])
                                for i in range(self.num_workers)])
        worker_stats = list(itertools.chain.from_iterable(worker_stats))
    stats = worker_stats[0].copy()
    return stats
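# Illustrative usage sketch (not from the source): fitting on a Spark
# DataFrame with explicit feature/label columns. The parquet path, column
# names and the steps_per_epoch arithmetic are hypothetical.
#
#   df = spark.read.parquet("hdfs://path/to/train.parquet")
#   stats = estimator.fit(df, epochs=5, batch_size=128,
#                         feature_cols=["user", "item"],
#                         label_cols=["label"],
#                         steps_per_epoch=df.count() // 128)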