class PyTorchRayEstimator(OrcaRayEstimator):
    """Orca estimator for PyTorch models that runs on Ray.

    Every public method delegates to an instance of
    ``zoo.orca.learn.pytorch.pytorch_ray_estimator.PyTorchRayEstimator``
    created in ``__init__`` and stored as ``self.estimator``.
    """

    def __init__(self,
                 *,
                 model_creator,
                 optimizer_creator,
                 loss_creator=None,
                 metrics=None,
                 scheduler_creator=None,
                 training_operator_cls=TrainingOperator,
                 initialization_hook=None,
                 config=None,
                 scheduler_step_freq="batch",
                 use_tqdm=False,
                 backend="torch_distributed",
                 workers_per_node=1):
        """Create the estimator and the underlying Ray estimator.

        All keyword arguments are forwarded unchanged to the underlying
        Ray estimator.

        :param config: Optional configuration dict. It must not contain the
               key "batch_size"; pass batch_size to fit/evaluate/predict
               instead.
        :raises ValueError: if ``config`` contains the key "batch_size".
        """
        if config is not None and "batch_size" in config:
            # ValueError is a subclass of Exception, so callers that catch
            # Exception around construction keep working.
            raise ValueError(
                "Please do not specify batch_size in config. Input batch_size in the"
                " fit/evaluate/predict function of the estimator instead.")
        # Imported inside __init__ and aliased so the imported class does not
        # shadow this class's own name.
        from zoo.orca.learn.pytorch.pytorch_ray_estimator import \
            PyTorchRayEstimator as _RayPyTorchEstimator
        self.estimator = _RayPyTorchEstimator(
            model_creator=model_creator,
            optimizer_creator=optimizer_creator,
            loss_creator=loss_creator,
            metrics=metrics,
            scheduler_creator=scheduler_creator,
            training_operator_cls=training_operator_cls,
            initialization_hook=initialization_hook,
            config=config,
            scheduler_step_freq=scheduler_step_freq,
            use_tqdm=use_tqdm,
            backend=backend,
            workers_per_node=workers_per_node)

    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            profile=False,
            reduce_results=True,
            info=None,
            feature_cols=None,
            label_cols=None):
        """
        Trains a PyTorch model given training data for several epochs.
        Calls `TrainingOperator.train_epoch()` on N parallel workers simultaneously
        underneath the hood.

        :param data: An instance of SparkXShards, a Spark DataFrame or a function that
               takes config and batch_size as argument and returns a PyTorch DataLoader for
               training.
        :param epochs: The number of epochs to train the model. Default is 1.
        :param batch_size: The number of samples per batch for each worker. Default is 32.
               The total batch size would be batch_size*workers_per_node*num_nodes.
               If your training data is a function, you can set batch_size to be the input
               batch_size of the function for the PyTorch DataLoader.
        :param profile: Boolean. Whether to return time stats for the training procedure.
               Default is False.
        :param reduce_results: Boolean. Whether to average all metrics across all workers into
               one dict. If a metric is a non-numerical value (or nested dictionaries), one value
               will be randomly selected among the workers. If False, returns a list of dicts for
               all workers.
               Default is True.
        :param info: An optional dictionary that can be passed to the TrainingOperator for
               train_epoch and train_batch.
        :param feature_cols: feature column names if data is Spark DataFrame.
        :param label_cols: label column names if data is Spark DataFrame.

        :return: A list of dictionary of metrics for every training epoch. If reduce_results is
                False, this will return a nested list of metric dictionaries whose length will be
                equal to the total number of workers.
                You can also provide custom metrics by passing in a custom training_operator_cls
                when creating the Estimator.
        """
        return self.estimator.train(data=data, epochs=epochs, batch_size=batch_size,
                                    profile=profile, reduce_results=reduce_results,
                                    info=info, feature_cols=feature_cols,
                                    label_cols=label_cols)

    def predict(self, data, batch_size=32, feature_cols=None, profile=False):
        """
        Using this PyTorch model to make predictions on the data.

        :param data: An instance of SparkXShards or a Spark DataFrame
        :param batch_size: The number of samples per batch for each worker. Default is 32.
        :param profile: Boolean. Whether to return time stats for the prediction procedure.
               Default is False.
        :param feature_cols: feature column names if data is a Spark DataFrame.
        :return: A SparkXShards that contains the predictions with key "prediction" in each shard
        """
        return self.estimator.predict(data, batch_size=batch_size,
                                      feature_cols=feature_cols,
                                      profile=profile)

    def evaluate(self, data, batch_size=32, num_steps=None, profile=False, info=None,
                 feature_cols=None, label_cols=None):
        """
        Evaluates a PyTorch model given validation data.
        Note that only accuracy for classification with zero-based label is supported by
        default. You can override validate_batch in TrainingOperator for other metrics.
        Calls `TrainingOperator.validate()` on N parallel workers simultaneously
        underneath the hood.

        :param data: An instance of SparkXShards, a Spark DataFrame or a function that
               takes config and batch_size as argument and returns a PyTorch DataLoader for
               validation.
        :param batch_size: The number of samples per batch for each worker. Default is 32.
               The total batch size would be batch_size*workers_per_node*num_nodes.
               If your validation data is a function, you can set batch_size to be the input
               batch_size of the function for the PyTorch DataLoader.
        :param num_steps: The number of batches to compute the validation results on. This
               corresponds to the number of times `TrainingOperator.validate_batch` is called.
        :param profile: Boolean. Whether to return time stats for the evaluation procedure.
               Default is False.
        :param info: An optional dictionary that can be passed to the TrainingOperator
               for validate.
        :param feature_cols: feature column names if data is Spark DataFrame.
        :param label_cols: label column names if data is Spark DataFrame.

        :return: A dictionary of metrics for the given data, including validation accuracy and
                loss. You can also provide custom metrics by passing in a custom
                training_operator_cls when creating the Estimator.
        """
        return self.estimator.validate(data=data, batch_size=batch_size, num_steps=num_steps,
                                       profile=profile, info=info, feature_cols=feature_cols,
                                       label_cols=label_cols)

    def get_model(self):
        """
        Returns the learned PyTorch model.

        :return: The learned PyTorch model.
        """
        return self.estimator.get_model()

    def save(self, checkpoint):
        """
        Saves the Estimator state (including model and optimizer) to the provided checkpoint
        path.

        :param checkpoint: (str) Path to target checkpoint file.
        :return: Whatever the underlying Ray estimator's ``save`` returns.
        """
        return self.estimator.save(checkpoint=checkpoint)

    def load(self, checkpoint):
        """
        Loads the Estimator state (including model and optimizer) from the provided checkpoint.

        :param checkpoint: (str) Path to target checkpoint file.
        :return: Whatever the underlying Ray estimator's ``load`` returns.
        """
        return self.estimator.load(checkpoint=checkpoint)

    def shutdown(self, force=False):
        """
        Shuts down workers and releases resources.

        :param force: Boolean, passed through unchanged to the underlying Ray estimator's
               ``shutdown``. Default is False.
        :return: Whatever the underlying Ray estimator's ``shutdown`` returns.
        """
        return self.estimator.shutdown(force=force)
class PyTorchRayEstimatorWrapper(Estimator):
    """Estimator wrapper around the Ray-based PyTorch estimator.

    Training and evaluation data are supplied as creator functions. Every
    implemented method delegates to ``self.estimator``, an instance of
    ``zoo.orca.learn.pytorch.pytorch_ray_estimator.PyTorchRayEstimator``.
    """

    def __init__(self,
                 *,
                 model_creator,
                 optimizer_creator,
                 loss_creator=None,
                 scheduler_creator=None,
                 training_operator_cls=TrainingOperator,
                 initialization_hook=None,
                 config=None,
                 scheduler_step_freq="batch",
                 use_tqdm=False,
                 backend="pytorch",
                 workers_per_node=1):
        """Create the wrapper; all arguments are forwarded to the Ray estimator."""
        # Imported inside __init__ and aliased so the imported class name is
        # distinct from the classes defined in this module.
        from zoo.orca.learn.pytorch.pytorch_ray_estimator import \
            PyTorchRayEstimator as _RayPyTorchEstimator
        self.estimator = _RayPyTorchEstimator(
            model_creator=model_creator,
            optimizer_creator=optimizer_creator,
            loss_creator=loss_creator,
            scheduler_creator=scheduler_creator,
            training_operator_cls=training_operator_cls,
            initialization_hook=initialization_hook,
            config=config,
            scheduler_step_freq=scheduler_step_freq,
            use_tqdm=use_tqdm,
            backend=backend,
            workers_per_node=workers_per_node)

    def fit(self, data, epochs=1, profile=False, reduce_results=True, info=None):
        """
        :param data: (callable) a function that takes a config dict as input and return a data
            loader containing the training data.
        :param epochs: (int) Number of epochs to train the model
        :param profile: (bool) Returns time stats for the training procedure.
        :param reduce_results: (bool) Whether to average all metrics across all workers into one
            dict. If a metric is a non-numerical value (or nested dictionaries), one value will be
            randomly selected among the workers. If False, returns a list of dicts.
        :param info: (dict) Optional dictionary passed to the training operator for
            ``train_epoch`` and ``train_batch``.
        :return: (list) A list of stats whose length will be equal to ``epochs``.
            stats is a dictionary of metrics for training.
            You can provide custom metrics by passing in a custom
            ``training_operator_cls``. If ``reduce_results=False``,
            this will return a list of metric dictionaries whose
            length will be equal to ``num_workers``.
        """
        # `data` is passed positionally: other code calling this same
        # underlying estimator uses the keyword `data=`, whereas this wrapper
        # previously used `data_creator=`. A positional argument binds under
        # either name. Assumes the data argument is the first parameter of the
        # underlying `train` -- TODO confirm.
        return self.estimator.train(data, epochs=epochs, profile=profile,
                                    reduce_results=reduce_results, info=info)

    def predict(self, data, **kwargs):
        """Prediction is not supported by this wrapper.

        :raises NotImplementedError: always. Previously this method silently
            returned None, which hid the missing functionality from callers.
        """
        raise NotImplementedError(
            "predict is not supported by PyTorchRayEstimatorWrapper")

    def evaluate(self, data, num_steps=None, profile=False, info=None):
        """
        :param data: (callable) a function that takes a config dict as input and return
            a data loader containing the validation data.
        :param num_steps: (int) Number of batches to compute update steps on.
            This corresponds also to the number of times
            ``TrainingOperator.validate_batch`` is called.
        :param profile: (bool) Returns time stats for the evaluation procedure.
        :param info: (dict) Optional dictionary passed to the training
            operator for `validate` and `validate_batch`.
        :return: A dictionary of metrics for validation.
            You can provide custom metrics by passing in a custom
            ``training_operator_cls``.
        """
        # Positional for the same reason as in fit(); assumes the data
        # argument is the first parameter of the underlying `validate` --
        # TODO confirm.
        return self.estimator.validate(data, num_steps=num_steps, profile=profile,
                                       info=info)

    def get_model(self):
        """Returns the learned model(s)."""
        return self.estimator.get_model()

    def save(self, checkpoint):
        """Saves the Estimator state to the provided checkpoint path.

        :param checkpoint: (str) Path to target checkpoint file.
        """
        return self.estimator.save(checkpoint=checkpoint)

    def load(self, checkpoint):
        """Loads the Estimator and all workers from the provided checkpoint.

        :param checkpoint: (str) Path to target checkpoint file.
        """
        return self.estimator.load(checkpoint=checkpoint)

    def shutdown(self, force=False):
        """Shuts down workers and releases resources.

        :param force: (bool) Passed through unchanged to the underlying Ray
            estimator's ``shutdown``.
        """
        return self.estimator.shutdown(force=force)